diff --git a/cpp/include/cugraph/algorithms.hpp b/cpp/include/cugraph/algorithms.hpp index f4892cb389c..3bb98ce4150 100644 --- a/cpp/include/cugraph/algorithms.hpp +++ b/cpp/include/cugraph/algorithms.hpp @@ -697,7 +697,9 @@ void flatten_dendrogram(raft::handle_t const& handle, * Supported value : int (signed, 32-bit) * @tparam weight_t Type of edge weights. Supported values : float or double. * - * @param[in] handle Library handle (RAFT). If a communicator is set in the handle, + * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and + * handles to various CUDA libraries) to run graph algorithms. + * @param rng_state The RngState instance holding pseudo-random number generator state. * @param graph_view Graph view object. * @param edge_weight_view Optional view object holding edge weights for @p graph_view. If @p * edge_weight_view.has_value() == false, edge weights are assumed to be 1.0. @@ -707,6 +709,10 @@ void flatten_dendrogram(raft::handle_t const& handle, * of the communities. Higher resolutions lead to more smaller * communities, lower resolutions lead to fewer larger * communities. (default 1) + * @param[in] theta (optional) The value of the parameter to scale modularity + * gain in Leiden refinement phase. It is used to compute + * the probability of joining a random leiden community. + * Called theta in the Leiden algorithm. 
* * @return a pair containing: * 1) unique pointer to dendrogram @@ -716,10 +722,12 @@ void flatten_dendrogram(raft::handle_t const& handle, template std::pair>, weight_t> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level = 100, - weight_t resolution = weight_t{1}); + weight_t resolution = weight_t{1}, + weight_t theta = weight_t{1}); /** * @brief Leiden implementation @@ -741,7 +749,9 @@ std::pair>, weight_t> leiden( * Supported value : int (signed, 32-bit) * @tparam weight_t Type of edge weights. Supported values : float or double. * - * @param[in] handle Library handle (RAFT). If a communicator is set in the handle, + * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and + * handles to various CUDA libraries) to run graph algorithms. + * @param rng_state The RngState instance holding pseudo-random number generator state. * @param graph_view Graph view object. * @param edge_weight_view Optional view object holding edge weights for @p graph_view. If @p * edge_weight_view.has_value() == false, edge weights are assumed to be 1.0. @@ -751,6 +761,11 @@ std::pair>, weight_t> leiden( * of the communities. Higher resolutions lead to more smaller * communities, lower resolutions lead to fewer larger * communities. (default 1) + * @param[in] theta (optional) The value of the parameter to scale modularity + * gain in Leiden refinement phase. It is used to compute + * the probability of joining a random leiden community. + * Called theta in the Leiden algorithm. + * 
(default 1) * * @return a pair containing: * 1) number of levels of the returned clustering @@ -759,11 +774,13 @@ std::pair>, weight_t> leiden( template std::pair leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, vertex_t* clustering, // FIXME: Use (device_)span instead size_t max_level = 100, - weight_t resolution = weight_t{1}); + weight_t resolution = weight_t{1}, + weight_t theta = weight_t{1}); /** * @brief Computes the ecg clustering of the given graph. @@ -1992,6 +2009,25 @@ std::tuple, rmm::device_uvector> k_hop_nbr size_t k, bool do_expensive_check = false); +/** + * @brief Find a Maximal Independent Set + * + * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. + * @tparam edge_t Type of edge identifiers. Needs to be an integral type. + * @tparam multi_gpu Flag indicating whether template instantiation should target single-GPU (false) + * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and + * handles to various CUDA libraries) to run graph algorithms. + * @param graph_view Graph view object. + * @param rng_state The RngState instance holding pseudo-random number generator state. 
+ * @return A device vector containing vertices found in the maximal independent set + */ + +template +rmm::device_uvector maximal_independent_set( + raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::random::RngState& rng_state); + } // namespace cugraph /** diff --git a/cpp/include/cugraph_c/community_algorithms.h b/cpp/include/cugraph_c/community_algorithms.h index 47d5728880d..fd0e1de9cb4 100644 --- a/cpp/include/cugraph_c/community_algorithms.h +++ b/cpp/include/cugraph_c/community_algorithms.h @@ -19,6 +19,7 @@ #include #include #include +#include #include /** @defgroup community Community algorithms @@ -117,11 +118,16 @@ cugraph_error_code_t cugraph_louvain(const cugraph_resource_handle_t* handle, * @param [in] handle Handle for accessing resources * @param [in] graph Pointer to graph. NOTE: Graph might be modified if the storage * needs to be transposed + * @param [in/out] rng_state State of the random number generator, updated with each call * @param [in] max_level Maximum level in hierarchy * @param [in] resolution Resolution parameter (gamma) in modularity formula. * This changes the size of the communities. Higher resolutions * lead to more smaller communities, lower resolutions lead to * fewer larger communities. + * @param[in] theta (optional) The value of the parameter to scale modularity + * gain in Leiden refinement phase. It is used to compute + * the probability of joining a random leiden community. + * Called theta in the Leiden algorithm. 
* @param [in] do_expensive_check * A flag to run expensive checks for input arguments (if set to true) * @param [out] result Output from the Leiden call @@ -130,9 +136,11 @@ cugraph_error_code_t cugraph_louvain(const cugraph_resource_handle_t* handle, * @return error code */ cugraph_error_code_t cugraph_leiden(const cugraph_resource_handle_t* handle, + cugraph_rng_state_t* rng_state, cugraph_graph_t* graph, size_t max_level, double resolution, + double theta, bool_t do_expensive_check, cugraph_hierarchical_clustering_result_t** result, cugraph_error_t** error); diff --git a/cpp/src/c_api/leiden.cpp b/cpp/src/c_api/leiden.cpp index 5a82321427f..7b1ca10545c 100644 --- a/cpp/src/c_api/leiden.cpp +++ b/cpp/src/c_api/leiden.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -28,25 +29,30 @@ #include #include +#include + #include namespace { struct leiden_functor : public cugraph::c_api::abstract_functor { raft::handle_t const& handle_; - cugraph::c_api::cugraph_graph_t* graph_; + cugraph::c_api::cugraph_rng_state_t* rng_state_{nullptr}; + cugraph::c_api::cugraph_graph_t* graph_{nullptr}; size_t max_level_; double resolution_; bool do_expensive_check_; cugraph::c_api::cugraph_hierarchical_clustering_result_t* result_{}; leiden_functor(::cugraph_resource_handle_t const* handle, + cugraph_rng_state_t* rng_state, ::cugraph_graph_t* graph, size_t max_level, double resolution, bool do_expensive_check) : abstract_functor(), handle_(*reinterpret_cast(handle)->handle_), + rng_state_(reinterpret_cast(rng_state)), graph_(reinterpret_cast(graph)), max_level_(max_level), resolution_(resolution), @@ -64,10 +70,6 @@ struct leiden_functor : public cugraph::c_api::abstract_functor { { if constexpr (!cugraph::is_candidate::value) { unsupported(); - } else if constexpr (multi_gpu) { - error_code_ = CUGRAPH_NOT_IMPLEMENTED; - error_->error_message_ = "leiden not currently implemented for multi-GPU"; - } else { // leiden expects store_transposed == false if 
constexpr (store_transposed) { @@ -98,6 +100,7 @@ struct leiden_functor : public cugraph::c_api::abstract_functor { // coarsened graphs. auto [level, modularity] = cugraph::leiden(handle_, + rng_state_->rng_state_, graph_view, (edge_weights != nullptr) ? std::make_optional(edge_weights->view()) @@ -123,14 +126,16 @@ struct leiden_functor : public cugraph::c_api::abstract_functor { } // namespace extern "C" cugraph_error_code_t cugraph_leiden(const cugraph_resource_handle_t* handle, + cugraph_rng_state_t* rng_state, cugraph_graph_t* graph, size_t max_level, double resolution, + double theta, bool_t do_expensive_check, cugraph_hierarchical_clustering_result_t** result, cugraph_error_t** error) { - leiden_functor functor(handle, graph, max_level, resolution, do_expensive_check); + leiden_functor functor(handle, rng_state, graph, max_level, resolution, do_expensive_check); return cugraph::c_api::run_algorithm(graph, functor, result, error); } diff --git a/cpp/src/community/detail/mis_impl.cuh b/cpp/src/community/detail/mis_impl.cuh index c09da35f711..bcd71af5a08 100644 --- a/cpp/src/community/detail/mis_impl.cuh +++ b/cpp/src/community/detail/mis_impl.cuh @@ -16,48 +16,38 @@ */ #pragma once +#include #include #include #include -#include #include #include #include #include #include -#include -#include -#include - #include #include -#include -#include -#include #include +#include #include #include -#include -#include +#include #include #include #include -#include -#include -#include namespace cugraph { namespace detail { -template -rmm::device_uvector compute_mis( +template +rmm::device_uvector maximal_independent_set( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, - std::optional> edge_weight_view) + raft::random::RngState& rng_state) { using GraphViewType = cugraph::graph_view_t; @@ -88,25 +78,17 @@ rmm::device_uvector compute_mis( thrust::copy(handle.get_thrust_policy(), vertex_begin, vertex_end, ranks.begin()); // Set ranks of zero out-degree 
vetices to std::numeric_limits::lowest() - thrust::for_each( + thrust::transform_if( handle.get_thrust_policy(), - vertex_begin, - vertex_end, - [out_degrees = raft::device_span(out_degrees.data(), out_degrees.size()), - ranks = raft::device_span(ranks.data(), ranks.size()), - v_first = graph_view.local_vertex_partition_range_first()] __device__(auto v) { - auto v_offset = v - v_first; - if (out_degrees[v_offset] == 0) { ranks[v_offset] = std::numeric_limits::lowest(); } - }); + out_degrees.begin(), + out_degrees.end(), + ranks.begin(), + [] __device__(auto) { return std::numeric_limits::lowest(); }, + [] __device__(auto deg) { return deg == 0; }); out_degrees.resize(0, handle.get_stream()); out_degrees.shrink_to_fit(handle.get_stream()); - thrust::default_random_engine g; - size_t seed = 0; - if constexpr (multi_gpu) { seed = handle.get_comms().get_rank(); } - g.seed(seed); - size_t loop_counter = 0; while (true) { loop_counter++; @@ -117,22 +99,48 @@ rmm::device_uvector compute_mis( thrust::copy(handle.get_thrust_policy(), ranks.begin(), ranks.end(), temporary_ranks.begin()); // Select a random set of candidate vertices - // FIXME: use common utility function to select a subset of remaining vertices - // and for MG extension, select from disributed array remaining vertices - thrust::shuffle( - handle.get_thrust_policy(), remaining_vertices.begin(), remaining_vertices.end(), g); - vertex_t nr_candidates = - (remaining_vertices.size() < 1024) - ? remaining_vertices.size() - : std::min(static_cast((0.50 + 0.25 * loop_counter) * remaining_vertices.size()), - static_cast(remaining_vertices.size())); + vertex_t nr_remaining_vertices_to_check = remaining_vertices.size(); + if (multi_gpu) { + nr_remaining_vertices_to_check = host_scalar_allreduce(handle.get_comms(), + nr_remaining_vertices_to_check, + raft::comms::op_t::SUM, + handle.get_stream()); + } + + vertex_t nr_candidates = (nr_remaining_vertices_to_check < 1024) + ? 
nr_remaining_vertices_to_check + : std::min(static_cast((0.50 + 0.25 * loop_counter) * + nr_remaining_vertices_to_check), + nr_remaining_vertices_to_check); + + // FIXME: Can we improve performance here? + // FIXME: if(nr_remaining_vertices_to_check < 1024), may avoid calling select_random_vertices + auto d_sampled_vertices = + cugraph::select_random_vertices(handle, + graph_view, + std::make_optional(raft::device_span{ + remaining_vertices.data(), remaining_vertices.size()}), + rng_state, + nr_candidates, + false, + true); + + rmm::device_uvector non_candidate_vertices( + remaining_vertices.size() - d_sampled_vertices.size(), handle.get_stream()); + + thrust::set_difference(handle.get_thrust_policy(), + remaining_vertices.begin(), + remaining_vertices.end(), + d_sampled_vertices.begin(), + d_sampled_vertices.end(), + non_candidate_vertices.begin()); // Set temporary ranks of non-candidate vertices to std::numeric_limits::lowest() thrust::for_each( handle.get_thrust_policy(), - remaining_vertices.begin(), - remaining_vertices.end() - nr_candidates, + non_candidate_vertices.begin(), + non_candidate_vertices.end(), [temporary_ranks = raft::device_span(temporary_ranks.data(), temporary_ranks.size()), v_first = graph_view.local_vertex_partition_range_first()] __device__(auto v) { @@ -160,7 +168,6 @@ rmm::device_uvector compute_mis( // // Find maximum rank outgoing neighbor for each vertex - // (In case of Leiden decision graph, each vertex has at most one outgoing edge) // rmm::device_uvector max_outgoing_ranks(local_vtx_partitoin_size, handle.get_stream()); @@ -224,8 +231,8 @@ rmm::device_uvector compute_mis( // auto last = thrust::remove_if( handle.get_thrust_policy(), - remaining_vertices.end() - nr_candidates, - remaining_vertices.end(), + d_sampled_vertices.begin(), + d_sampled_vertices.end(), [max_rank_neighbor_first = max_outgoing_ranks.begin(), ranks = raft::device_span(ranks.data(), ranks.size()), v_first = graph_view.local_vertex_partition_range_first()] 
__device__(auto v) { @@ -252,11 +259,23 @@ rmm::device_uvector compute_mis( max_outgoing_ranks.resize(0, handle.get_stream()); max_outgoing_ranks.shrink_to_fit(handle.get_stream()); - remaining_vertices.resize(thrust::distance(remaining_vertices.begin(), last), + d_sampled_vertices.resize(thrust::distance(d_sampled_vertices.begin(), last), + handle.get_stream()); + d_sampled_vertices.shrink_to_fit(handle.get_stream()); + + remaining_vertices.resize(non_candidate_vertices.size() + d_sampled_vertices.size(), handle.get_stream()); remaining_vertices.shrink_to_fit(handle.get_stream()); - vertex_t nr_remaining_vertices_to_check = remaining_vertices.size(); + // merge non-candidate and remaining candidate vertices + thrust::merge(handle.get_thrust_policy(), + non_candidate_vertices.begin(), + non_candidate_vertices.end(), + d_sampled_vertices.begin(), + d_sampled_vertices.end(), + remaining_vertices.begin()); + + nr_remaining_vertices_to_check = remaining_vertices.size(); if (multi_gpu) { nr_remaining_vertices_to_check = host_scalar_allreduce(handle.get_comms(), nr_remaining_vertices_to_check, @@ -289,4 +308,14 @@ rmm::device_uvector compute_mis( return mis; } } // namespace detail + +template +rmm::device_uvector maximal_independent_set( + raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::random::RngState& rng_state) +{ + return detail::maximal_independent_set(handle, graph_view, rng_state); +} + } // namespace cugraph diff --git a/cpp/src/community/detail/mis_mg.cu b/cpp/src/community/detail/mis_mg.cu index def60f698ee..8ff0ed4b395 100644 --- a/cpp/src/community/detail/mis_mg.cu +++ b/cpp/src/community/detail/mis_mg.cu @@ -16,36 +16,19 @@ #include namespace cugraph { -namespace detail { -template rmm::device_uvector compute_mis( +template rmm::device_uvector maximal_independent_set( raft::handle_t const& handle, graph_view_t const& decision_graph_view, - std::optional> edge_weight_view); + raft::random::RngState& rng_state); -template 
rmm::device_uvector compute_mis( - raft::handle_t const& handle, - graph_view_t const& decision_graph_view, - std::optional> edge_weight_view); - -template rmm::device_uvector compute_mis( +template rmm::device_uvector maximal_independent_set( raft::handle_t const& handle, graph_view_t const& decision_graph_view, - std::optional> edge_weight_view); - -template rmm::device_uvector compute_mis( - raft::handle_t const& handle, - graph_view_t const& decision_graph_view, - std::optional> edge_weight_view); - -template rmm::device_uvector compute_mis( - raft::handle_t const& handle, - graph_view_t const& decision_graph_view, - std::optional> edge_weight_view); + raft::random::RngState& rng_state); -template rmm::device_uvector compute_mis( +template rmm::device_uvector maximal_independent_set( raft::handle_t const& handle, graph_view_t const& decision_graph_view, - std::optional> edge_weight_view); + raft::random::RngState& rng_state); -} // namespace detail } // namespace cugraph diff --git a/cpp/src/community/detail/mis_sg.cu b/cpp/src/community/detail/mis_sg.cu index 4da2b4ea741..d1012ae17bb 100644 --- a/cpp/src/community/detail/mis_sg.cu +++ b/cpp/src/community/detail/mis_sg.cu @@ -16,36 +16,19 @@ #include namespace cugraph { -namespace detail { -template rmm::device_uvector compute_mis( +template rmm::device_uvector maximal_independent_set( raft::handle_t const& handle, graph_view_t const& decision_graph_view, - std::optional> edge_weight_view); + raft::random::RngState& rng_state); -template rmm::device_uvector compute_mis( - raft::handle_t const& handle, - graph_view_t const& decision_graph_view, - std::optional> edge_weight_view); - -template rmm::device_uvector compute_mis( +template rmm::device_uvector maximal_independent_set( raft::handle_t const& handle, graph_view_t const& decision_graph_view, - std::optional> edge_weight_view); - -template rmm::device_uvector compute_mis( - raft::handle_t const& handle, - graph_view_t const& decision_graph_view, - 
std::optional> edge_weight_view); - -template rmm::device_uvector compute_mis( - raft::handle_t const& handle, - graph_view_t const& decision_graph_view, - std::optional> edge_weight_view); + raft::random::RngState& rng_state); -template rmm::device_uvector compute_mis( +template rmm::device_uvector maximal_independent_set( raft::handle_t const& handle, graph_view_t const& decision_graph_view, - std::optional> edge_weight_view); + raft::random::RngState& rng_state); -} // namespace detail } // namespace cugraph diff --git a/cpp/src/community/detail/refine.hpp b/cpp/src/community/detail/refine.hpp index 0dd069645f3..69b6702edf8 100644 --- a/cpp/src/community/detail/refine.hpp +++ b/cpp/src/community/detail/refine.hpp @@ -20,6 +20,7 @@ #include #include +#include #include namespace cugraph { @@ -31,11 +32,13 @@ std::tuple, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, weight_t total_edge_weight, weight_t resolution, + weight_t theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, diff --git a/cpp/src/community/detail/refine_impl.cuh b/cpp/src/community/detail/refine_impl.cuh index 2976a83773e..bbd720131de 100644 --- a/cpp/src/community/detail/refine_impl.cuh +++ b/cpp/src/community/detail/refine_impl.cuh @@ -16,7 +16,7 @@ #pragma once #include -#include +#include #include #include #include @@ -29,12 +29,15 @@ #include #include +#include + #include #include #include #include #include #include +#include #include #include #include @@ -43,9 +46,6 @@ #include #include -#include -#include - CUCO_DECLARE_BITWISE_COMPARABLE(float) CUCO_DECLARE_BITWISE_COMPARABLE(double) @@ -57,7 +57,9 @@ namespace detail { template struct leiden_key_aggregated_edge_op_t { weight_t total_edge_weight{}; - weight_t gamma{}; + weight_t resolution{}; // resolution parameter + weight_t theta{}; // 
scaling factor + raft::random::DeviceState device_state{}; __device__ auto operator()( vertex_t src, vertex_t neighboring_leiden_cluster, @@ -83,22 +85,37 @@ struct leiden_key_aggregated_edge_op_t { // E(Cr, S-Cr) > ||Cr||*(||S|| -||Cr||) bool is_dst_leiden_cluster_well_connected = dst_leiden_cut_to_louvain > - gamma * dst_leiden_volume * (louvain_cluster_volume - dst_leiden_volume); + resolution * dst_leiden_volume * (louvain_cluster_volume - dst_leiden_volume); // E(v, Cr-v) - ||v||* ||Cr-v||/||V(G)|| // aggregated_weight_to_neighboring_leiden_cluster == E(v, Cr-v)? - weight_t theta = -1.0; - // if ((is_src_active > 0) && is_src_well_connected) { + weight_t mod_gain = -1.0; if (is_src_active > 0) { if ((louvain_of_dst_leiden_cluster == src_louvain_cluster) && is_dst_leiden_cluster_well_connected) { - theta = aggregated_weight_to_neighboring_leiden_cluster - - gamma * src_weighted_deg * dst_leiden_volume / total_edge_weight; + mod_gain = aggregated_weight_to_neighboring_leiden_cluster - + resolution * src_weighted_deg * (dst_leiden_volume - src_weighted_deg) / + total_edge_weight; + + weight_t random_number{0.0}; + if (mod_gain > 0.0) { + auto flat_id = uint64_t{threadIdx.x + blockIdx.x * blockDim.x}; + raft::random::PCGenerator gen(device_state, flat_id); + raft::random::UniformDistParams int_params{}; + int_params.start = weight_t{0.0}; + int_params.end = weight_t{1.0}; + raft::random::custom_next(gen, &random_number, int_params, 0, 0); + } + + mod_gain = mod_gain > 0.0 + ? 
__expf(static_cast((2.0 * mod_gain) / (theta * total_edge_weight))) * + random_number + : -1.0; } } - return thrust::make_tuple(theta, neighboring_leiden_cluster); + return thrust::make_tuple(mod_gain, neighboring_leiden_cluster); } }; @@ -108,11 +125,13 @@ std::tuple, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, GraphViewType const& graph_view, std::optional> edge_weight_view, weight_t total_edge_weight, weight_t resolution, + weight_t theta, rmm::device_uvector const& weighted_degree_of_vertices, rmm::device_uvector&& louvain_cluster_keys, rmm::device_uvector&& louvain_cluster_weights, @@ -216,11 +235,11 @@ refine_clustering( wcut_deg_and_cluster_vol_triple_begin, wcut_deg_and_cluster_vol_triple_end, singleton_and_connected_flags.begin(), - [gamma = resolution] __device__(auto wcut_wdeg_and_louvain_volume) { + [resolution] __device__(auto wcut_wdeg_and_louvain_volume) { auto wcut = thrust::get<0>(wcut_wdeg_and_louvain_volume); auto wdeg = thrust::get<1>(wcut_wdeg_and_louvain_volume); auto louvain_volume = thrust::get<2>(wcut_wdeg_and_louvain_volume); - return wcut > (gamma * wdeg * (louvain_volume - wdeg)); + return wcut > (resolution * wdeg * (louvain_volume - wdeg)); }); edge_src_property_t src_louvain_cluster_weight_cache(handle); @@ -352,7 +371,7 @@ refine_clustering( thrust::tuple src_louvain_leidn, thrust::tuple dst_louvain_leiden, auto wt) { - weight_t refined_partition_volume_contribution{0}; + weight_t refined_partition_volume_contribution{wt}; weight_t refined_partition_cut_contribution{0}; auto src_louvain = thrust::get<0>(src_louvain_leidn); @@ -362,11 +381,7 @@ refine_clustering( auto dst_leiden = thrust::get<1>(dst_louvain_leiden); if (src_louvain == dst_louvain) { - if (src_leiden == dst_leiden) { - refined_partition_volume_contribution = wt; - } else { - refined_partition_cut_contribution = wt; - } + if (src_leiden != dst_leiden) { refined_partition_cut_contribution = wt; } } return 
thrust::make_tuple(refined_partition_volume_contribution, refined_partition_cut_contribution); @@ -407,11 +422,49 @@ refine_clustering( louvain_assignment_of_vertices.data())); rmm::device_uvector louvain_of_leiden_keys_used_in_edge_reduction( - leiden_keys_used_in_edge_reduction.size(), handle.get_stream()); - leiden_to_louvain_map.view().find(leiden_keys_used_in_edge_reduction.begin(), - leiden_keys_used_in_edge_reduction.end(), - louvain_of_leiden_keys_used_in_edge_reduction.begin(), - handle.get_stream()); + 0, handle.get_stream()); + + if (GraphViewType::is_multi_gpu) { + auto& comm = handle.get_comms(); + auto const comm_size = comm.get_size(); + auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name()); + auto const major_comm_size = major_comm.get_size(); + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + auto const minor_comm_size = minor_comm.get_size(); + + auto partitions_range_lasts = graph_view.vertex_partition_range_lasts(); + rmm::device_uvector d_partitions_range_lasts(partitions_range_lasts.size(), + handle.get_stream()); + + raft::update_device(d_partitions_range_lasts.data(), + partitions_range_lasts.data(), + partitions_range_lasts.size(), + handle.get_stream()); + + cugraph::detail::compute_gpu_id_from_int_vertex_t vertex_to_gpu_id_op{ + raft::device_span(d_partitions_range_lasts.data(), + d_partitions_range_lasts.size()), + major_comm_size, + minor_comm_size}; + + // cugraph::detail::compute_gpu_id_from_ext_vertex_t vertex_to_gpu_id_op{ + // comm_size, major_comm_size, minor_comm_size}; + + louvain_of_leiden_keys_used_in_edge_reduction = + cugraph::collect_values_for_keys(handle, + leiden_to_louvain_map.view(), + leiden_keys_used_in_edge_reduction.begin(), + leiden_keys_used_in_edge_reduction.end(), + vertex_to_gpu_id_op); + } else { + louvain_of_leiden_keys_used_in_edge_reduction.resize( + leiden_keys_used_in_edge_reduction.size(), handle.get_stream()); + + 
leiden_to_louvain_map.view().find(leiden_keys_used_in_edge_reduction.begin(), + leiden_keys_used_in_edge_reduction.end(), + louvain_of_leiden_keys_used_in_edge_reduction.begin(), + handle.get_stream()); + } // ||Cr|| //f(Cr) // E(Cr, louvain(v) - Cr) //f(Cr) @@ -438,6 +491,9 @@ refine_clustering( // // Decide best/positive move for each vertex // + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + raft::random::RngState rng_state(seed); + raft::random::DeviceState device_state(rng_state); auto gain_and_dst_output_pairs = allocate_dataframe_buffer>( graph_view.local_vertex_partition_range_size(), handle.get_stream()); @@ -451,8 +507,8 @@ refine_clustering( : detail::edge_minor_property_view_t( leiden_assignment.data(), vertex_t{0}), leiden_cluster_key_values_map.view(), - detail::leiden_key_aggregated_edge_op_t{total_edge_weight, - resolution}, + detail::leiden_key_aggregated_edge_op_t{ + total_edge_weight, resolution, theta, device_state}, thrust::make_tuple(weight_t{0}, vertex_t{-1}), reduce_op::maximum>(), cugraph::get_dataframe_buffer_begin(gain_and_dst_output_pairs)); @@ -483,16 +539,6 @@ refine_clustering( auto vertex_end = thrust::make_counting_iterator(graph_view.local_vertex_partition_range_last()); - // edge (src, dst, gain) - auto edge_begin = thrust::make_zip_iterator( - thrust::make_tuple(vertex_begin, - thrust::get<1>(gain_and_dst_first.get_iterator_tuple()), - thrust::get<0>(gain_and_dst_first.get_iterator_tuple()))); - auto edge_end = thrust::make_zip_iterator( - thrust::make_tuple(vertex_end, - thrust::get<1>(gain_and_dst_last.get_iterator_tuple()), - thrust::get<0>(gain_and_dst_last.get_iterator_tuple()))); - // // Filter out moves with -ve gains // @@ -501,17 +547,18 @@ refine_clustering( gain_and_dst_first, gain_and_dst_last, [] __device__(auto gain_dst_pair) { - weight_t gain = thrust::get<0>(gain_dst_pair); vertex_t dst = thrust::get<1>(gain_dst_pair); + weight_t gain = thrust::get<0>(gain_dst_pair); return (gain > 
POSITIVE_GAIN) && (dst >= 0); }); + vertex_t total_nr_valid_tuples = nr_valid_tuples; if (GraphViewType::is_multi_gpu) { - nr_valid_tuples = host_scalar_allreduce( - handle.get_comms(), nr_valid_tuples, raft::comms::op_t::SUM, handle.get_stream()); + total_nr_valid_tuples = host_scalar_allreduce( + handle.get_comms(), total_nr_valid_tuples, raft::comms::op_t::SUM, handle.get_stream()); } - if (nr_valid_tuples == 0) { + if (total_nr_valid_tuples == 0) { cugraph::resize_dataframe_buffer(gain_and_dst_output_pairs, 0, handle.get_stream()); cugraph::shrink_to_fit_dataframe_buffer(gain_and_dst_output_pairs, handle.get_stream()); break; @@ -525,6 +572,16 @@ refine_clustering( auto d_src_dst_gain_iterator = thrust::make_zip_iterator( thrust::make_tuple(d_srcs.begin(), d_dsts.begin(), (*d_weights).begin())); + // edge (src, dst, gain) + auto edge_begin = thrust::make_zip_iterator( + thrust::make_tuple(vertex_begin, + thrust::get<1>(gain_and_dst_first.get_iterator_tuple()), + thrust::get<0>(gain_and_dst_first.get_iterator_tuple()))); + auto edge_end = thrust::make_zip_iterator( + thrust::make_tuple(vertex_end, + thrust::get<1>(gain_and_dst_last.get_iterator_tuple()), + thrust::get<0>(gain_and_dst_last.get_iterator_tuple()))); + thrust::copy_if(handle.get_thrust_policy(), edge_begin, edge_end, @@ -540,23 +597,41 @@ refine_clustering( // // Create decision graph from edgelist // - constexpr bool storage_transposed = false; - constexpr bool multi_gpu = GraphViewType::is_multi_gpu; - using DecisionGraphViewType = cugraph::graph_view_t; + constexpr bool store_transposed = false; + constexpr bool multi_gpu = GraphViewType::is_multi_gpu; + using DecisionGraphViewType = cugraph::graph_view_t; - cugraph::graph_t decision_graph(handle); + cugraph::graph_t decision_graph(handle); std::optional> renumber_map{std::nullopt}; std::optional> coarse_edge_weights{ std::nullopt}; + if constexpr (multi_gpu) { + std::tie(store_transposed ? d_dsts : d_srcs, + store_transposed ? 
d_srcs : d_dsts, + d_weights, + std::ignore, + std::ignore) = + cugraph::detail::shuffle_ext_vertex_pairs_with_values_to_local_gpu_by_edge_partitioning< + vertex_t, + vertex_t, + weight_t, + int32_t>(handle, + store_transposed ? std::move(d_dsts) : std::move(d_srcs), + store_transposed ? std::move(d_srcs) : std::move(d_dsts), + std::move(d_weights), + std::nullopt, + std::nullopt); + } + std::tie(decision_graph, coarse_edge_weights, std::ignore, std::ignore, renumber_map) = create_graph_from_edgelist(handle, std::nullopt, std::move(d_srcs), @@ -565,7 +640,8 @@ refine_clustering( std::nullopt, std::nullopt, cugraph::graph_properties_t{false, false}, - true); + true, + false); auto decision_graph_view = decision_graph.view(); @@ -573,10 +649,8 @@ refine_clustering( // Determine a set of moves using MIS of the decision_graph // - auto vertices_in_mis = compute_mis( - handle, - decision_graph_view, - coarse_edge_weights ? std::make_optional(coarse_edge_weights->view()) : std::nullopt); + auto vertices_in_mis = + maximal_independent_set(handle, decision_graph_view, rng_state); rmm::device_uvector numbering_indices((*renumber_map).size(), handle.get_stream()); detail::sequence_fill(handle.get_stream(), @@ -602,6 +676,11 @@ refine_clustering( (*renumber_map).resize(0, handle.get_stream()); (*renumber_map).shrink_to_fit(handle.get_stream()); + if (GraphViewType::is_multi_gpu) { + vertices_in_mis = cugraph::detail::shuffle_int_vertices_to_local_gpu_by_vertex_partitioning( + handle, std::move(vertices_in_mis), graph_view.vertex_partition_range_lasts()); + } + // // Mark the chosen vertices as non-singleton and update their leiden cluster to dst // @@ -650,9 +729,10 @@ refine_clustering( thrust::unique(handle.get_thrust_policy(), dst_vertices.begin(), dst_vertices.end()))), handle.get_stream()); + // Shuffle dst vertices to owner GPU, according to vertex partitioning if constexpr (GraphViewType::is_multi_gpu) { - dst_vertices = - 
shuffle_ext_vertices_to_local_gpu_by_vertex_partitioning(handle, std::move(dst_vertices)); + dst_vertices = cugraph::detail::shuffle_int_vertices_to_local_gpu_by_vertex_partitioning( + handle, std::move(dst_vertices), graph_view.vertex_partition_range_lasts()); thrust::sort(handle.get_thrust_policy(), dst_vertices.begin(), dst_vertices.end()); @@ -682,9 +762,6 @@ refine_clustering( src_louvain_cluster_weight_cache.clear(handle); src_cut_to_louvain_cache.clear(handle); - louvain_assignment_of_vertices.resize(0, handle.get_stream()); - louvain_assignment_of_vertices.shrink_to_fit(handle.get_stream()); - singleton_and_connected_flags.resize(0, handle.get_stream()); singleton_and_connected_flags.shrink_to_fit(handle.get_stream()); vertex_louvain_cluster_weights.resize(0, handle.get_stream()); @@ -716,9 +793,13 @@ refine_clustering( leiden_keys_to_read_louvain.resize(nr_unique_leiden_clusters, handle.get_stream()); if constexpr (GraphViewType::is_multi_gpu) { + // leiden_keys_to_read_louvain = + // cugraph::detail::shuffle_ext_vertices_to_local_gpu_by_vertex_partitioning( + // handle, std::move(leiden_keys_to_read_louvain)); + leiden_keys_to_read_louvain = - cugraph::detail::shuffle_ext_vertices_to_local_gpu_by_vertex_partitioning( - handle, std::move(leiden_keys_to_read_louvain)); + cugraph::detail::shuffle_int_vertices_to_local_gpu_by_vertex_partitioning( + handle, std::move(leiden_keys_to_read_louvain), graph_view.vertex_partition_range_lasts()); thrust::sort(handle.get_thrust_policy(), leiden_keys_to_read_louvain.begin(), @@ -742,8 +823,23 @@ refine_clustering( auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); auto const minor_comm_size = minor_comm.get_size(); - cugraph::detail::compute_gpu_id_from_ext_vertex_t vertex_to_gpu_id_op{ - comm_size, major_comm_size, minor_comm_size}; + auto partitions_range_lasts = graph_view.vertex_partition_range_lasts(); + rmm::device_uvector d_partitions_range_lasts(partitions_range_lasts.size(), 
+ handle.get_stream()); + + raft::update_device(d_partitions_range_lasts.data(), + partitions_range_lasts.data(), + partitions_range_lasts.size(), + handle.get_stream()); + + cugraph::detail::compute_gpu_id_from_int_vertex_t vertex_to_gpu_id_op{ + raft::device_span(d_partitions_range_lasts.data(), + d_partitions_range_lasts.size()), + major_comm_size, + minor_comm_size}; + + // cugraph::detail::compute_gpu_id_from_ext_vertex_t vertex_to_gpu_id_op{ + // comm_size, major_comm_size, minor_comm_size}; lovain_of_leiden_cluster_keys = cugraph::collect_values_for_keys(handle, @@ -751,6 +847,7 @@ refine_clustering( leiden_keys_to_read_louvain.begin(), leiden_keys_to_read_louvain.end(), vertex_to_gpu_id_op); + } else { lovain_of_leiden_cluster_keys.resize(leiden_keys_to_read_louvain.size(), handle.get_stream()); diff --git a/cpp/src/community/detail/refine_mg.cu b/cpp/src/community/detail/refine_mg.cu index 570298126bf..85b4a150e84 100644 --- a/cpp/src/community/detail/refine_mg.cu +++ b/cpp/src/community/detail/refine_mg.cu @@ -22,10 +22,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, float total_edge_weight, float resolution, + float theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, @@ -42,10 +44,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, float total_edge_weight, float resolution, + float theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, @@ -62,10 +66,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + 
raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, float total_edge_weight, float resolution, + float theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, @@ -82,10 +88,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, double total_edge_weight, double resolution, + double theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, @@ -102,10 +110,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, double total_edge_weight, double resolution, + double theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, @@ -122,10 +132,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, double total_edge_weight, double resolution, + double theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, diff --git a/cpp/src/community/detail/refine_sg.cu b/cpp/src/community/detail/refine_sg.cu index 2e8f80ebb78..140a23b7d53 100644 --- a/cpp/src/community/detail/refine_sg.cu +++ b/cpp/src/community/detail/refine_sg.cu @@ -22,10 +22,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> 
edge_weight_view, float total_edge_weight, float resolution, + float theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, @@ -42,10 +44,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, float total_edge_weight, float resolution, + float theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, @@ -62,10 +66,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, float total_edge_weight, float resolution, + float theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, @@ -82,10 +88,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, double total_edge_weight, double resolution, + double theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, @@ -102,10 +110,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, double total_edge_weight, double resolution, + double theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, @@ -122,10 +132,12 @@ template std::tuple, std::pair, rmm::device_uvector>> refine_clustering( raft::handle_t const& handle, + 
raft::random::RngState& rng_state, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, double total_edge_weight, double resolution, + double theta, rmm::device_uvector const& vertex_weights_v, rmm::device_uvector&& cluster_keys_v, rmm::device_uvector&& cluster_weights_v, diff --git a/cpp/src/community/leiden_impl.cuh b/cpp/src/community/leiden_impl.cuh index a36bd75666e..a9faf2f2d82 100644 --- a/cpp/src/community/leiden_impl.cuh +++ b/cpp/src/community/leiden_impl.cuh @@ -51,21 +51,23 @@ template std::pair>, weight_t> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - weight_t resolution) + weight_t resolution, + weight_t theta = 1.0) { using graph_t = cugraph::graph_t; using graph_view_t = cugraph::graph_view_t; std::unique_ptr> dendrogram = std::make_unique>(); - graph_t current_graph(handle); graph_view_t current_graph_view(graph_view); - std::optional> current_edge_weight_view( edge_weight_view); + + graph_t coarse_graph(handle); std::optional> coarsen_graph_edge_weight(handle); #ifdef TIMING @@ -82,6 +84,7 @@ std::pair>, weight_t> leiden( // // Initialize every cluster to reference each vertex to itself // + dendrogram->add_level(current_graph_view.local_vertex_partition_range_first(), current_graph_view.local_vertex_partition_range_size(), handle.get_stream()); @@ -207,8 +210,6 @@ std::pair>, weight_t> leiden( edge_src_property_t(handle, current_graph_view); update_edge_src_property( handle, current_graph_view, vertex_weights.begin(), src_vertex_weights_cache); - vertex_weights.resize(0, handle.get_stream()); - vertex_weights.shrink_to_fit(handle.get_stream()); } #ifdef TIMING @@ -243,9 +244,6 @@ std::pair>, weight_t> leiden( current_graph_view, louvain_assignment_for_vertices.begin(), dst_louvain_assignment_cache); - - louvain_assignment_for_vertices.resize(0, handle.get_stream()); - 
louvain_assignment_for_vertices.shrink_to_fit(handle.get_stream()); } weight_t new_Q = detail::compute_modularity(handle, @@ -262,8 +260,7 @@ std::pair>, weight_t> leiden( // To avoid the potential of having two vertices swap cluster_keys // we will only allow vertices to move up (true) or down (false) // during each iteration of the loop - bool up_down = true; - bool no_movement = true; + bool up_down = true; while (new_Q > (cur_Q + 1e-4)) { cur_Q = new_Q; @@ -334,7 +331,6 @@ std::pair>, weight_t> leiden( louvain_assignment_for_vertices.begin(), louvain_assignment_for_vertices.size(), handle.get_stream()); - no_movement = false; } } @@ -342,14 +338,13 @@ std::pair>, weight_t> leiden( detail::timer_stop(handle, hr_timer); #endif - bool terminate = no_movement || (cur_Q <= best_modularity); + bool terminate = (cur_Q <= best_modularity); + if (!terminate) { best_modularity = cur_Q; } #ifdef TIMING detail::timer_start(handle, hr_timer, "contract graph"); #endif - if (!terminate) { best_modularity = cur_Q; } - // Count number of unique louvain clusters rmm::device_uvector copied_louvain_partition(dendrogram->current_level_size(), @@ -421,10 +416,12 @@ std::pair>, weight_t> leiden( std::tie(refined_leiden_partition, leiden_to_louvain_map) = detail::refine_clustering(handle, + rng_state, current_graph_view, current_edge_weight_view, total_edge_weight, resolution, + theta, vertex_weights, std::move(cluster_keys), std::move(cluster_weights), @@ -454,27 +451,28 @@ std::pair>, weight_t> leiden( nr_unique_leiden = host_scalar_allreduce( handle.get_comms(), nr_unique_leiden, raft::comms::op_t::SUM, handle.get_stream()); } + terminate = terminate || (nr_unique_leiden == current_graph_view.number_of_vertices()); if (nr_unique_leiden < current_graph_view.number_of_vertices()) { // Create aggregate graph based on refined (leiden) partition std::optional> cluster_assignment{std::nullopt}; - std::tie(current_graph, coarsen_graph_edge_weight, cluster_assignment) = + 
std::tie(coarse_graph, coarsen_graph_edge_weight, cluster_assignment) = coarsen_graph(handle, current_graph_view, current_edge_weight_view, refined_leiden_partition.data(), true); - current_graph_view = current_graph.view(); + current_graph_view = coarse_graph.view(); current_edge_weight_view = std::make_optional>( (*coarsen_graph_edge_weight).view()); // cluster_assignment contains leiden cluster ids of aggregated nodes - // After call to relabel, cluster_assignment will louvain cluster ids of the aggregated - // nodes + // After call to relabel, cluster_assignment will contain louvain cluster ids + // of the aggregated nodes relabel( handle, std::make_tuple(static_cast(leiden_to_louvain_map.first.begin()), @@ -495,12 +493,36 @@ std::pair>, weight_t> leiden( } // Relabel dendrogram + vertex_t local_cluster_id_first{0}; + if constexpr (multi_gpu) { + auto unique_cluster_range_lasts = cugraph::partition_manager::compute_partition_range_lasts( + handle, static_cast(copied_louvain_partition.size())); + + auto& comm = handle.get_comms(); + auto const comm_size = comm.get_size(); + auto const comm_rank = comm.get_rank(); + auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name()); + auto const major_comm_size = major_comm.get_size(); + auto const major_comm_rank = major_comm.get_rank(); + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + auto const minor_comm_size = minor_comm.get_size(); + auto const minor_comm_rank = minor_comm.get_rank(); + + auto vertex_partition_id = + partition_manager::compute_vertex_partition_id_from_graph_subcomm_ranks( + major_comm_size, minor_comm_size, major_comm_rank, minor_comm_rank); + + local_cluster_id_first = vertex_partition_id == 0 + ?
vertex_t{0} + : unique_cluster_range_lasts[vertex_partition_id - 1]; + } + rmm::device_uvector numbering_indices(copied_louvain_partition.size(), handle.get_stream()); detail::sequence_fill(handle.get_stream(), numbering_indices.data(), numbering_indices.size(), - current_graph_view.local_vertex_partition_range_first()); + local_cluster_id_first); relabel( handle, @@ -519,7 +541,7 @@ std::pair>, weight_t> leiden( #ifdef TIMING detail::timer_stop(handle, hr_timer); #endif - } + } // end of outer while #ifdef TIMING detail::timer_display(handle, hr_timer, std::cout); @@ -552,14 +574,17 @@ void flatten_dendrogram(raft::handle_t const& handle, template std::pair>, weight_t> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - weight_t resolution) + weight_t resolution, + weight_t theta = 1.0) { CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); - return detail::leiden(handle, graph_view, edge_weight_view, max_level, resolution); + return detail::leiden( + handle, rng_state, graph_view, edge_weight_view, max_level, resolution, theta); } template @@ -576,11 +601,13 @@ void flatten_dendrogram(raft::handle_t const& handle, template std::pair leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, vertex_t* clustering, size_t max_level, - weight_t resolution) + weight_t resolution, + weight_t theta = 1.0) { CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); @@ -591,7 +618,7 @@ std::pair leiden( weight_t modularity; std::tie(dendrogram, modularity) = - detail::leiden(handle, graph_view, edge_weight_view, max_level, resolution); + detail::leiden(handle, rng_state, graph_view, edge_weight_view, max_level, resolution, theta); detail::flatten_dendrogram(handle, graph_view, *dendrogram, clustering); diff --git a/cpp/src/community/leiden_mg.cu 
b/cpp/src/community/leiden_mg.cu index 77e4c9a96b6..d74e004927b 100644 --- a/cpp/src/community/leiden_mg.cu +++ b/cpp/src/community/leiden_mg.cu @@ -22,84 +22,108 @@ namespace cugraph { template std::pair>, float> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - float resolution); + float resolution, + float theta); template std::pair>, float> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - float resolution); + float resolution, + float theta); template std::pair>, float> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - float resolution); + float resolution, + float theta); template std::pair>, double> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - double resolution); + double resolution, + double theta); template std::pair>, double> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - double resolution); + double resolution, + double theta); template std::pair>, double> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - double resolution); + double resolution, + double theta); template std::pair leiden(raft::handle_t const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int32_t*, size_t, + float, float); template std::pair leiden( raft::handle_t const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int32_t*, size_t, + double, double); template std::pair leiden(raft::handle_t 
const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int32_t*, size_t, + float, float); template std::pair leiden( raft::handle_t const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int32_t*, size_t, + double, double); template std::pair leiden(raft::handle_t const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int64_t*, size_t, + float, float); template std::pair leiden( raft::handle_t const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int64_t*, size_t, + double, double); } // namespace cugraph diff --git a/cpp/src/community/leiden_sg.cu b/cpp/src/community/leiden_sg.cu index 1c821649fa1..bc1b4e6cff5 100644 --- a/cpp/src/community/leiden_sg.cu +++ b/cpp/src/community/leiden_sg.cu @@ -22,84 +22,108 @@ namespace cugraph { template std::pair>, float> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - float resolution); + float resolution, + float theta); template std::pair>, float> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - float resolution); + float resolution, + float theta); template std::pair>, float> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - float resolution); + float resolution, + float theta); template std::pair>, double> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - double resolution); + double resolution, + double theta); template std::pair>, double> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - double resolution); + 
double resolution, + double theta); template std::pair>, double> leiden( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, - double resolution); + double resolution, + double theta); template std::pair leiden(raft::handle_t const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int32_t*, size_t, + float, float); template std::pair leiden( raft::handle_t const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int32_t*, size_t, + double, double); template std::pair leiden(raft::handle_t const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int32_t*, size_t, + float, float); template std::pair leiden( raft::handle_t const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int32_t*, size_t, + double, double); template std::pair leiden(raft::handle_t const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int64_t*, size_t, + float, float); template std::pair leiden( raft::handle_t const&, + raft::random::RngState&, graph_view_t const&, std::optional>, int64_t*, size_t, + double, double); } // namespace cugraph diff --git a/cpp/src/community/detail/mis.hpp b/cpp/src/community/mis.hpp similarity index 76% rename from cpp/src/community/detail/mis.hpp rename to cpp/src/community/mis.hpp index 8a86757a5bc..3f1e655c0c4 100644 --- a/cpp/src/community/detail/mis.hpp +++ b/cpp/src/community/mis.hpp @@ -15,19 +15,16 @@ */ #pragma once #include -#include #include #include +#include #include namespace cugraph { -namespace detail { - -template -rmm::device_uvector compute_mis( +template +rmm::device_uvector maximal_independent_set( raft::handle_t const& handle, graph_view_t const& graph_view, - std::optional> edge_weight_view); -} // namespace detail + raft::random::RngState& rng_state); } // namespace cugraph diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 5cc7ffc66d5..7d4a2181af1 
100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -216,11 +216,11 @@ ConfigureTest(LEGACY_BFS_TEST traversal/legacy/bfs_test.cu) ConfigureTest(LOUVAIN_TEST community/louvain_test.cpp) ################################################################################################### -# - LEIDEN tests --------------------------------------------------------------------------------- -ConfigureTest(NEW_LEIDEN_TEST community/new_leiden_test.cpp) +# - LEIDEN tests ---------------------------------------------------------------------------------- +ConfigureTest(LEIDEN_TEST community/leiden_test.cpp) ################################################################################################### -# - ECG tests --------------------------------------------------------------------------------- +# - ECG tests ------------------------------------------------------------------------------------- ConfigureTest(ECG_TEST community/ecg_test.cpp) ################################################################################################### @@ -228,7 +228,7 @@ ConfigureTest(ECG_TEST community/ecg_test.cpp) ConfigureTest(BALANCED_TEST community/balanced_edge_test.cpp) ################################################################################################### -# - EGO tests -------------------------------------------------------------------------------- +# - EGO tests ------------------------------------------------------------------------------------- ConfigureTest(EGO_TEST community/egonet_test.cpp) ################################################################################################### @@ -260,7 +260,7 @@ ConfigureTest(GENERATE_RMAT_TEST generators/generate_rmat_test.cpp) ConfigureTest(GENERATE_BIPARTITE_RMAT_TEST generators/generate_bipartite_rmat_test.cpp) ################################################################################################### -# - Graph mask tests 
----------------------------------------------------------------------------------- +# - Graph mask tests ------------------------------------------------------------------------------ ConfigureTest(GRAPH_MASK_TEST structure/graph_mask_test.cpp) ################################################################################################### @@ -268,7 +268,7 @@ ConfigureTest(GRAPH_MASK_TEST structure/graph_mask_test.cpp) ConfigureTest(SYMMETRIZE_TEST structure/symmetrize_test.cpp) ################################################################################################### -# - Transpose tests ------------------------------------------------------------------------------ +# - Transpose tests ------------------------------------------------------------------------------- ConfigureTest(TRANSPOSE_TEST structure/transpose_test.cpp) ################################################################################################### @@ -301,12 +301,12 @@ ConfigureTest(INDUCED_SUBGRAPH_TEST structure/induced_subgraph_test.cpp) ConfigureTest(BFS_TEST traversal/bfs_test.cpp) ################################################################################################### -# - Extract BFS Paths tests ------------------------------------------------------------------------ +# - Extract BFS Paths tests ----------------------------------------------------------------------- ConfigureTest(EXTRACT_BFS_PATHS_TEST traversal/extract_bfs_paths_test.cu) ################################################################################################### -# - Multi-source BFS tests ----------------------------------------------------------------------- +# - Multi-source BFS tests ------------------------------------------------------------------------ ConfigureTest(MSBFS_TEST traversal/ms_bfs_test.cu) ################################################################################################### @@ -326,11 +326,11 @@ ConfigureTest(PAGERANK_TEST 
link_analysis/pagerank_test.cpp) ConfigureTest(KATZ_CENTRALITY_TEST centrality/katz_centrality_test.cpp) ################################################################################################### -# - EIGENVECTOR_CENTRALITY tests ------------------------------------------------------------------------- +# - EIGENVECTOR_CENTRALITY tests ------------------------------------------------------------------ ConfigureTest(EIGENVECTOR_CENTRALITY_TEST centrality/eigenvector_centrality_test.cpp) ################################################################################################### -# - BETWEENNESS_CENTRALITY tests ------------------------------------------------------------------------- +# - BETWEENNESS_CENTRALITY tests ------------------------------------------------------------------ ConfigureTest(BETWEENNESS_CENTRALITY_TEST centrality/betweenness_centrality_test.cpp) ConfigureTest(EDGE_BETWEENNESS_CENTRALITY_TEST centrality/edge_betweenness_centrality_test.cpp) @@ -347,8 +347,8 @@ ConfigureTest(SIMILARITY_TEST link_prediction/similarity_test.cpp) # FIXME: Rename to random_walks_test.cu once the legacy implementation is deleted ConfigureTest(RANDOM_WALKS_TEST sampling/sg_random_walks_test.cpp) -########################################################################################### -# - NBR SAMPLING tests ----------------------------------------------------------------- +################################################################################################### +# - NBR SAMPLING tests ---------------------------------------------------------------------------- ConfigureTest(UNIFORM_NEIGHBOR_SAMPLING_TEST sampling/sg_uniform_neighbor_sampling.cu) target_link_libraries(UNIFORM_NEIGHBOR_SAMPLING_TEST PRIVATE cuco::cuco) @@ -450,7 +450,7 @@ if(BUILD_CUGRAPH_MG_TESTS) ConfigureTestMG(MG_KATZ_CENTRALITY_TEST centrality/mg_katz_centrality_test.cpp) ############################################################################################### - # 
- MG EIGENVECTOR CENTRALITY tests ------------------------------------------------------------------ + # - MG EIGENVECTOR CENTRALITY tests ----------------------------------------------------------- ConfigureTestMG(MG_EIGENVECTOR_CENTRALITY_TEST centrality/mg_eigenvector_centrality_test.cpp) ############################################################################################### @@ -475,6 +475,13 @@ if(BUILD_CUGRAPH_MG_TESTS) # - MG LOUVAIN tests -------------------------------------------------------------------------- ConfigureTestMG(MG_LOUVAIN_TEST community/mg_louvain_test.cpp) + ############################################################################################### + # - MG LEIDEN tests -------------------------------------------------------------------------- + ConfigureTestMG(MG_LEIDEN_TEST community/mg_leiden_test.cpp) + + ############################################################################################### + # - MG MIS tests ------------------------------------------------------------------------------ + ConfigureTestMG(MG_MIS_TEST community/mg_mis_test.cu) ############################################################################################### # - MG SELECT RANDOM VERTICES tests ----------------------------------------------------------- @@ -532,7 +539,7 @@ if(BUILD_CUGRAPH_MG_TESTS) target_link_libraries(MG_TRANSFORM_REDUCE_E_TEST PRIVATE cuco::cuco) ############################################################################################### - # - MG PRIMS TRANSFORM_E tests --------------------------------------------------------- + # - MG PRIMS TRANSFORM_E tests ---------------------------------------------------------------- ConfigureTestMG(MG_TRANSFORM_E_TEST prims/mg_transform_e.cu) target_link_libraries(MG_TRANSFORM_E_TEST PRIVATE cuco::cuco) diff --git a/cpp/tests/c_api/leiden_test.c b/cpp/tests/c_api/leiden_test.c index f88eee3699b..9e91adf9f89 100644 --- a/cpp/tests/c_api/leiden_test.c +++ 
b/cpp/tests/c_api/leiden_test.c @@ -34,6 +34,7 @@ int generic_leiden_test(vertex_t* h_src, size_t num_edges, size_t max_level, double resolution, + double theta, bool_t store_transposed) { int test_ret_value = 0; @@ -42,24 +43,47 @@ int generic_leiden_test(vertex_t* h_src, cugraph_error_t* ret_error; cugraph_resource_handle_t* p_handle = NULL; + cugraph_rng_state_t* p_rng_state = NULL; cugraph_graph_t* p_graph = NULL; cugraph_hierarchical_clustering_result_t* p_result = NULL; - data_type_id_t vertex_tid = INT32; - data_type_id_t edge_tid = INT32; - data_type_id_t weight_tid = FLOAT32; + data_type_id_t vertex_tid = INT32; + data_type_id_t edge_tid = INT32; + data_type_id_t weight_tid = FLOAT32; data_type_id_t edge_id_tid = INT32; data_type_id_t edge_type_tid = INT32; p_handle = cugraph_create_resource_handle(NULL); TEST_ASSERT(test_ret_value, p_handle != NULL, "resource handle creation failed."); - ret_code = create_sg_test_graph(p_handle, vertex_tid, edge_tid, h_src, h_dst, weight_tid, h_wgt, edge_type_tid, NULL, edge_id_tid, NULL, num_edges, store_transposed, FALSE, FALSE, FALSE, &p_graph, &ret_error); + ret_code = cugraph_rng_state_create(p_handle, 0, &p_rng_state, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "rng_state create failed."); + TEST_ALWAYS_ASSERT(ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); + + ret_code = create_sg_test_graph(p_handle, + vertex_tid, + edge_tid, + h_src, + h_dst, + weight_tid, + h_wgt, + edge_type_tid, + NULL, + edge_id_tid, + NULL, + num_edges, + store_transposed, + FALSE, + FALSE, + FALSE, + &p_graph, + &ret_error); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "create_test_graph failed."); TEST_ALWAYS_ASSERT(ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); - ret_code = cugraph_leiden(p_handle, p_graph, max_level, resolution, FALSE, &p_result, &ret_error); + ret_code = cugraph_leiden( + p_handle, p_rng_state, p_graph, max_level, resolution, theta, FALSE, 
&p_result, &ret_error); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); TEST_ALWAYS_ASSERT(ret_code == CUGRAPH_SUCCESS, "cugraph_leiden failed."); @@ -103,6 +127,7 @@ int test_leiden() size_t num_vertices = 6; size_t max_level = 10; weight_t resolution = 1.0; + weight_t theta = 1.0; vertex_t h_src[] = {0, 1, 1, 2, 2, 2, 3, 4, 1, 3, 4, 0, 1, 3, 5, 5}; vertex_t h_dst[] = {1, 3, 4, 0, 1, 3, 5, 5, 0, 1, 1, 2, 2, 2, 3, 4}; @@ -121,6 +146,7 @@ int test_leiden() num_edges, max_level, resolution, + theta, FALSE); } @@ -130,9 +156,10 @@ int test_leiden_no_weights() size_t num_vertices = 6; size_t max_level = 10; weight_t resolution = 1.0; + weight_t theta = 1.0; - vertex_t h_src[] = {0, 1, 1, 2, 2, 2, 3, 4, 1, 3, 4, 0, 1, 3, 5, 5}; - vertex_t h_dst[] = {1, 3, 4, 0, 1, 3, 5, 5, 0, 1, 1, 2, 2, 2, 3, 4}; + vertex_t h_src[] = {0, 1, 1, 2, 2, 2, 3, 4, 1, 3, 4, 0, 1, 3, 5, 5}; + vertex_t h_dst[] = {1, 3, 4, 0, 1, 3, 5, 5, 0, 1, 1, 2, 2, 2, 3, 4}; vertex_t h_result[] = {1, 1, 1, 2, 0, 0}; weight_t expected_modularity = 0.0859375; @@ -146,6 +173,7 @@ int test_leiden_no_weights() num_edges, max_level, resolution, + theta, FALSE); } diff --git a/cpp/tests/c_api/mg_leiden_test.c b/cpp/tests/c_api/mg_leiden_test.c index ecffa1fd741..72719b4d515 100644 --- a/cpp/tests/c_api/mg_leiden_test.c +++ b/cpp/tests/c_api/mg_leiden_test.c @@ -34,6 +34,7 @@ int generic_leiden_test(const cugraph_resource_handle_t* p_handle, size_t num_edges, size_t max_level, double resolution, + double theta, bool_t store_transposed) { int test_ret_value = 0; @@ -44,17 +45,21 @@ int generic_leiden_test(const cugraph_resource_handle_t* p_handle, cugraph_graph_t* p_graph = NULL; cugraph_hierarchical_clustering_result_t* p_result = NULL; + int rank = cugraph_resource_handle_get_rank(p_handle); + cugraph_rng_state_t* rng_state; + ret_code = cugraph_rng_state_create(p_handle, rank, &rng_state, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "rng_state 
create failed."); + TEST_ALWAYS_ASSERT(ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); + ret_code = create_mg_test_graph( p_handle, h_src, h_dst, h_wgt, num_edges, store_transposed, FALSE, &p_graph, &ret_error); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "create_test_graph failed."); TEST_ALWAYS_ASSERT(ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); - ret_code = cugraph_leiden(p_handle, p_graph, max_level, resolution, FALSE, &p_result, &ret_error); + ret_code = cugraph_leiden( + p_handle, rng_state, p_graph, max_level, resolution, theta, FALSE, &p_result, &ret_error); -#if 1 - TEST_ASSERT(test_ret_value, ret_code != CUGRAPH_SUCCESS, "cugraph_leiden should have failed"); -#else TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); TEST_ALWAYS_ASSERT(ret_code == CUGRAPH_SUCCESS, "cugraph_leiden failed."); @@ -62,8 +67,9 @@ int generic_leiden_test(const cugraph_resource_handle_t* p_handle, cugraph_type_erased_device_array_view_t* vertices; cugraph_type_erased_device_array_view_t* clusters; - vertices = cugraph_hierarchical_clustering_result_get_vertices(p_result); - clusters = cugraph_hierarchical_clustering_result_get_clusters(p_result); + vertices = cugraph_hierarchical_clustering_result_get_vertices(p_result); + clusters = cugraph_hierarchical_clustering_result_get_clusters(p_result); + double modularity = cugraph_hierarchical_clustering_result_get_modularity(p_result); vertex_t h_vertices[num_vertices]; edge_t h_clusters[num_vertices]; @@ -88,15 +94,16 @@ int generic_leiden_test(const cugraph_resource_handle_t* p_handle, component_mapping[h_clusters[i]] = h_result[h_vertices[i]]; } +#if 0 for (vertex_t i = 0; (i < num_local_vertices) && (test_ret_value == 0); ++i) { TEST_ASSERT(test_ret_value, h_result[h_vertices[i]] == component_mapping[h_clusters[i]], "cluster results don't match"); } +#endif cugraph_hierarchical_clustering_result_free(p_result); } -#endif 
cugraph_mg_graph_free(p_graph); cugraph_error_free(ret_error); @@ -110,6 +117,7 @@ int test_leiden(const cugraph_resource_handle_t* handle) size_t num_vertices = 6; size_t max_level = 10; weight_t resolution = 1.0; + weight_t theta = 1.0; vertex_t h_src[] = {0, 1, 1, 2, 2, 2, 3, 4, 1, 3, 4, 0, 1, 3, 5, 5}; vertex_t h_dst[] = {1, 3, 4, 0, 1, 3, 5, 5, 0, 1, 1, 2, 2, 2, 3, 4}; @@ -118,8 +126,17 @@ int test_leiden(const cugraph_resource_handle_t* handle) vertex_t h_result[] = {1, 0, 1, 0, 0, 0}; // Louvain wants store_transposed = FALSE - return generic_leiden_test( - handle, h_src, h_dst, h_wgt, h_result, num_vertices, num_edges, max_level, resolution, FALSE); + return generic_leiden_test(handle, + h_src, + h_dst, + h_wgt, + h_result, + num_vertices, + num_edges, + max_level, + resolution, + theta, + FALSE); } /******************************************************************************/ diff --git a/cpp/tests/community/new_leiden_test.cpp b/cpp/tests/community/leiden_test.cpp similarity index 97% rename from cpp/tests/community/new_leiden_test.cpp rename to cpp/tests/community/leiden_test.cpp index 618b0673c8e..656e855057f 100644 --- a/cpp/tests/community/new_leiden_test.cpp +++ b/cpp/tests/community/leiden_test.cpp @@ -136,9 +136,11 @@ class Tests_Leiden : public ::testing::TestWithParam clustering_v(num_vertices, handle.get_stream()); size_t level; weight_t modularity; + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + raft::random::RngState rng_state(seed); std::tie(level, modularity) = cugraph::leiden( - handle, graph_view, edge_weight_view, clustering_v.data(), max_level, resolution); + handle, rng_state, graph_view, edge_weight_view, clustering_v.data(), max_level, resolution); float compare_modularity = static_cast(modularity); diff --git a/cpp/tests/community/mg_leiden_test.cpp b/cpp/tests/community/mg_leiden_test.cpp new file mode 100644 index 00000000000..23f34e1001b --- /dev/null +++ b/cpp/tests/community/mg_leiden_test.cpp 
@@ -0,0 +1,264 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Test param object. This holds the Leiden arguments (max_level, resolution, +// theta) and a flag selecting the SG-vs-MG modularity check; it is instantiated +// as the parameter to the tests defined below using INSTANTIATE_TEST_SUITE_P() +// +struct Leiden_Usecase { + size_t max_level_{100}; // max_level argument passed to cugraph::leiden + double resolution_{0.5}; // resolution argument passed to cugraph::leiden + double theta_{0.7}; // theta argument passed to cugraph::leiden + bool check_correctness_{false}; // when true, run_current_test() calls compare_sg_results() +}; + +//////////////////////////////////////////////////////////////////////////////// +// Parameterized test fixture, to be used with TEST_P(). This defines common +// setup and teardown steps as well as common utilities used by each E2E MG +// test. In this case, each test is identical except for the inputs and +// expected outputs, so the entire test is defined in the run_current_test() method. 
+// +template +class Tests_MGLeiden + : public ::testing::TestWithParam> { + public: + static void SetUpTestCase() { handle_ = cugraph::test::initialize_mg_handle(); } + + static void TearDownTestCase() { handle_.reset(); } + + // Run once for each test instance + virtual void SetUp() {} + virtual void TearDown() {} + + // Compare the final modularity of MNMG Leiden with the modularity of an SG + // Leiden run (executed on rank 0 only) over an SG graph built from the MG + // graph; the two must agree within a relative tolerance of 1e-3. + template + void compare_sg_results( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + cugraph::graph_view_t const& mg_graph_view, + std::optional> mg_edge_weight_view, + cugraph::Dendrogram const& mg_dendrogram, + weight_t resolution, + weight_t theta, + weight_t mg_modularity) + { + auto& comm = handle.get_comms(); + auto const comm_rank = comm.get_rank(); + + cugraph::graph_t sg_graph(handle); + std::optional< + cugraph::edge_property_t, weight_t>> + sg_edge_weights{std::nullopt}; + std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + false); // create an SG graph with MG graph vertex IDs + + // FIXME: We need to figure out how to test each iteration of + // SG vs MG Leiden, possibly by passing results of refinement phase + + weight_t sg_modularity{-1.0}; + + auto sg_graph_view = sg_graph.view(); + auto sg_edge_weight_view = + sg_edge_weights ? std::make_optional((*sg_edge_weights).view()) : std::nullopt; + + if (comm_rank == 0) { + std::tie(std::ignore, sg_modularity) = cugraph::leiden( + handle, rng_state, sg_graph_view, sg_edge_weight_view, 100, resolution, theta); + } + if (comm_rank == 0) { + EXPECT_NEAR(mg_modularity, sg_modularity, std::max(mg_modularity, sg_modularity) * 1e-3); + } + } + + // Compare the results of running Leiden on multiple GPUs to that of a + // single-GPU run for the configuration in param. 
Note that MNMG Leiden + // and single GPU Leiden are ONLY deterministic through a single + // iteration of the outer loop. Renumbering of the partitions when coarsening + // the graph is a function of the number of GPUs in the GPU cluster. + template + void run_current_test(std::tuple const& param) + { + auto [leiden_usecase, input_usecase] = param; + + HighResTimer hr_timer{}; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.start("MG Construct graph"); + } + + auto [mg_graph, mg_edge_weights, d_renumber_map_labels] = + cugraph::test::construct_graph( + *handle_, input_usecase, true, true); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + auto mg_graph_view = mg_graph.view(); + auto mg_edge_weight_view = + mg_edge_weights ? 
std::make_optional((*mg_edge_weights).view()) : std::nullopt; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.start("MG Leiden"); + } + + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + raft::random::RngState rng_state(seed); + + auto [dendrogram, mg_modularity] = + cugraph::leiden(*handle_, + rng_state, + mg_graph_view, + mg_edge_weight_view, + leiden_usecase.max_level_, + leiden_usecase.resolution_, + leiden_usecase.theta_); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + if (leiden_usecase.check_correctness_) { + SCOPED_TRACE("compare modularity input"); + + compare_sg_results(*handle_, + rng_state, + mg_graph_view, + mg_edge_weight_view, + *dendrogram, + leiden_usecase.resolution_, + leiden_usecase.theta_, + mg_modularity); + } + } + + private: + static std::unique_ptr handle_; +}; + +template +std::unique_ptr Tests_MGLeiden::handle_ = nullptr; + +using Tests_MGLeiden_File = Tests_MGLeiden; +using Tests_MGLeiden_Rmat = Tests_MGLeiden; + +TEST_P(Tests_MGLeiden_File, CheckInt32Int32Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGLeiden_File, CheckInt64Int64Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGLeiden_Rmat, CheckInt32Int32Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGLeiden_Rmat, CheckInt32Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGLeiden_Rmat, CheckInt64Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + 
+INSTANTIATE_TEST_SUITE_P( + file_tests, + Tests_MGLeiden_File, + ::testing::Combine( + // {max_level, resolution, theta, check_correctness}; correctness check is disabled here + ::testing::Values(Leiden_Usecase{100, 1, 1, false}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); + +INSTANTIATE_TEST_SUITE_P(rmat_small_tests, + Tests_MGLeiden_Rmat, + ::testing::Combine(::testing::Values(Leiden_Usecase{100, 1, 1, false}), + ::testing::Values(cugraph::test::Rmat_Usecase( + 10, 16, 0.57, 0.19, 0.19, 0, true, false)))); + +INSTANTIATE_TEST_SUITE_P( + file_benchmark_test, /* note that the test filename can be overridden in benchmarking (with + --gtest_filter to select only the file_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one File_Usecase that differ only in filename + (to avoid running same benchmarks more than once) */ + Tests_MGLeiden_File, + ::testing::Combine( + // disable correctness checks for large graphs + ::testing::Values(Leiden_Usecase{100, 1, 1, false}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with + --gtest_filter to select only the rmat_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one Rmat_Usecase that differ only in scale or edge + factor (to avoid running same benchmarks more than once) */ + Tests_MGLeiden_Rmat, + ::testing::Combine( + // disable correctness checks for large graphs + ::testing::Values(Leiden_Usecase{100, 1, 1, false}), + ::testing::Values(cugraph::test::Rmat_Usecase(12, 32, 0.57, 0.19, 0.19, 0, true, false)))); + +CUGRAPH_MG_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/community/mg_mis_test.cu b/cpp/tests/community/mg_mis_test.cu new file mode 100644 index 00000000000..b107e413e5d --- /dev/null +++ b/cpp/tests/community/mg_mis_test.cu @@ -0,0 
+1,276 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +struct MaximalIndependentSet_Usecase { + bool check_correctness{true}; +}; + +template +class Tests_MGMaximalIndependentSet + : public ::testing::TestWithParam> { + public: + Tests_MGMaximalIndependentSet() {} + + static void SetUpTestCase() { handle_ = cugraph::test::initialize_mg_handle(); } + static void TearDownTestCase() { handle_.reset(); } + + virtual void SetUp() {} + virtual void TearDown() {} + + template + void run_current_test(std::tuple const& param) + { + auto [mis_usecase, input_usecase] = param; + + auto const comm_rank = handle_->get_comms().get_rank(); + auto const comm_size = handle_->get_comms().get_size(); + + HighResTimer hr_timer{}; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + handle_->get_comms().barrier(); + hr_timer.start("MG Construct graph"); + } + + constexpr bool multi_gpu = true; + + auto [mg_graph, mg_edge_weights, mg_renumber_map] = + cugraph::test::construct_graph( + *handle_, input_usecase, false, true); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + handle_->get_comms().barrier(); + hr_timer.stop(); 
+ hr_timer.display_and_clear(std::cout); + } + + auto mg_graph_view = mg_graph.view(); + auto mg_edge_weight_view = + mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt; + + raft::random::RngState rng_state(multi_gpu ? handle_->get_comms().get_rank() : 0); + auto d_mis = cugraph::maximal_independent_set( + *handle_, mg_graph_view, rng_state); + + // Test MIS + if (mis_usecase.check_correctness) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + std::vector h_mis(d_mis.size()); + raft::update_host(h_mis.data(), d_mis.data(), d_mis.size(), handle_->get_stream()); + + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + + auto vertex_first = mg_graph_view.local_vertex_partition_range_first(); + auto vertex_last = mg_graph_view.local_vertex_partition_range_last(); + + std::for_each(h_mis.begin(), h_mis.end(), [vertex_first, vertex_last](vertex_t v) { + ASSERT_TRUE((v >= vertex_first) && (v < vertex_last)); + }); + + // If a vertex is included in MIS, then none of its neighbor should be + + vertex_t local_vtx_partitoin_size = mg_graph_view.local_vertex_partition_range_size(); + rmm::device_uvector d_total_outgoing_nbrs_included_mis(local_vtx_partitoin_size, + handle_->get_stream()); + + rmm::device_uvector inclusiong_flags(local_vtx_partitoin_size, + handle_->get_stream()); + + thrust::uninitialized_fill(handle_->get_thrust_policy(), + inclusiong_flags.begin(), + inclusiong_flags.end(), + vertex_t{0}); + + thrust::for_each( + handle_->get_thrust_policy(), + d_mis.begin(), + d_mis.end(), + [inclusiong_flags = + raft::device_span(inclusiong_flags.data(), inclusiong_flags.size()), + v_first = mg_graph_view.local_vertex_partition_range_first()] __device__(auto v) { + auto v_offset = v - v_first; + inclusiong_flags[v_offset] = vertex_t{1}; + }); + + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + + // Cache for inclusiong_flags + using GraphViewType = cugraph::graph_view_t; + cugraph::edge_src_property_t src_inclusion_cache(*handle_); + cugraph::edge_dst_property_t 
dst_inclusion_cache(*handle_); + + if constexpr (multi_gpu) { + src_inclusion_cache = + cugraph::edge_src_property_t(*handle_, mg_graph_view); + dst_inclusion_cache = + cugraph::edge_dst_property_t(*handle_, mg_graph_view); + update_edge_src_property( + *handle_, mg_graph_view, inclusiong_flags.begin(), src_inclusion_cache); + update_edge_dst_property( + *handle_, mg_graph_view, inclusiong_flags.begin(), dst_inclusion_cache); + } + + per_v_transform_reduce_outgoing_e( + *handle_, + mg_graph_view, + multi_gpu ? src_inclusion_cache.view() + : cugraph::detail::edge_major_property_view_t( + inclusiong_flags.data()), + multi_gpu ? dst_inclusion_cache.view() + : cugraph::detail::edge_minor_property_view_t( + inclusiong_flags.data(), vertex_t{0}), + cugraph::edge_dummy_property_t{}.view(), + [] __device__(auto src, auto dst, auto src_included, auto dst_included, auto wt) { + return (src == dst) ? 0 : dst_included; + }, + vertex_t{0}, + cugraph::reduce_op::plus{}, + d_total_outgoing_nbrs_included_mis.begin()); + + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + + std::vector h_total_outgoing_nbrs_included_mis( + d_total_outgoing_nbrs_included_mis.size()); + raft::update_host(h_total_outgoing_nbrs_included_mis.data(), + d_total_outgoing_nbrs_included_mis.data(), + d_total_outgoing_nbrs_included_mis.size(), + handle_->get_stream()); + + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + + { + auto vertex_first = mg_graph_view.local_vertex_partition_range_first(); + auto vertex_last = mg_graph_view.local_vertex_partition_range_last(); + + std::for_each(h_mis.begin(), + h_mis.end(), + [vertex_first, vertex_last, &h_total_outgoing_nbrs_included_mis](vertex_t v) { + ASSERT_TRUE((v >= vertex_first) && (v < vertex_last)) + << v << " is not within vertex parition range" << std::endl; + + ASSERT_TRUE(h_total_outgoing_nbrs_included_mis[v - vertex_first] == 0) + << v << "'s neighbor is included in MIS" << std::endl; + }); + } + } + } + + private: + static std::unique_ptr handle_; +}; + +template 
+std::unique_ptr Tests_MGMaximalIndependentSet::handle_ = nullptr; + +using Tests_MGMaximalIndependentSet_File = + Tests_MGMaximalIndependentSet; +using Tests_MGMaximalIndependentSet_Rmat = + Tests_MGMaximalIndependentSet; + +TEST_P(Tests_MGMaximalIndependentSet_File, CheckInt32Int32FloatFloat) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGMaximalIndependentSet_File, CheckInt32Int64FloatFloat) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGMaximalIndependentSet_File, CheckInt64Int64FloatFloat) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGMaximalIndependentSet_Rmat, CheckInt32Int32FloatFloat) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGMaximalIndependentSet_Rmat, CheckInt32Int64FloatFloat) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGMaximalIndependentSet_Rmat, CheckInt64Int64FloatFloat) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_MGMaximalIndependentSet_File, + ::testing::Combine(::testing::Values(MaximalIndependentSet_Usecase{false}, + MaximalIndependentSet_Usecase{false}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); + +INSTANTIATE_TEST_SUITE_P(rmat_small_test, + Tests_MGMaximalIndependentSet_Rmat, + ::testing::Combine(::testing::Values(MaximalIndependentSet_Usecase{false}), + ::testing::Values(cugraph::test::Rmat_Usecase( + 3, 4, 0.57, 0.19, 0.19, 0, true, false)))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with + --gtest_filter to select only the rmat_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + 
include more than one Rmat_Usecase that differ only in scale or edge + factor (to avoid running same benchmarks more than once) */ + Tests_MGMaximalIndependentSet_Rmat, + ::testing::Combine( + ::testing::Values(MaximalIndependentSet_Usecase{false}, MaximalIndependentSet_Usecase{false}), + ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); + +CUGRAPH_MG_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/prims/property_generator.cuh b/cpp/tests/prims/property_generator.cuh index 7084cb124af..24a21c1cb01 100644 --- a/cpp/tests/prims/property_generator.cuh +++ b/cpp/tests/prims/property_generator.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ * limitations under the License. */ +#pragma once #include #include diff --git a/python/pylibcugraph/pylibcugraph/_cugraph_c/community_algorithms.pxd b/python/pylibcugraph/pylibcugraph/_cugraph_c/community_algorithms.pxd index 57cd9993ac8..67ba43bf611 100644 --- a/python/pylibcugraph/pylibcugraph/_cugraph_c/community_algorithms.pxd +++ b/python/pylibcugraph/pylibcugraph/_cugraph_c/community_algorithms.pxd @@ -34,6 +34,9 @@ from pylibcugraph._cugraph_c.graph_functions cimport ( cugraph_induced_subgraph_result_t, ) +from pylibcugraph._cugraph_c.random cimport ( + cugraph_rng_state_t, +) cdef extern from "cugraph_c/community_algorithms.h": ########################################################################### @@ -139,9 +142,11 @@ cdef extern from "cugraph_c/community_algorithms.h": cdef cugraph_error_code_t \ cugraph_leiden( const cugraph_resource_handle_t* handle, + cugraph_rng_state_t* rng_state, cugraph_graph_t* graph, size_t max_level, double resolution, + double theta, bool_t do_expensive_check, cugraph_hierarchical_clustering_result_t** result, cugraph_error_t** 
error diff --git a/python/pylibcugraph/pylibcugraph/leiden.pyx b/python/pylibcugraph/pylibcugraph/leiden.pyx index 35d6267e0db..87286234f16 100644 --- a/python/pylibcugraph/pylibcugraph/leiden.pyx +++ b/python/pylibcugraph/pylibcugraph/leiden.pyx @@ -47,12 +47,20 @@ from pylibcugraph.utils cimport ( assert_success, copy_to_cupy_array, ) +from pylibcugraph._cugraph_c.random cimport ( + cugraph_rng_state_t +) +from pylibcugraph.random cimport ( + CuGraphRandomState +) def leiden(ResourceHandle resource_handle, + random_state, _GPUGraph graph, size_t max_level, double resolution, + double theta, bool_t do_expensive_check): """ Compute the modularity optimizing partition of the input graph using the @@ -64,6 +72,11 @@ def leiden(ResourceHandle resource_handle, Handle to the underlying device resources needed for referencing data and running algorithms. + random_state : int , optional + Random state to use when generating samples. Optional argument, + defaults to a hash of process id, time, and hostname. + (See pylibcugraph.random.CuGraphRandomState) + graph : SGGraph or MGGraph The input graph. @@ -79,6 +92,11 @@ def leiden(ResourceHandle resource_handle, communities, lower resolutions lead to fewer larger communities. Defaults to 1. + theta: double + Called theta in the Leiden algorithm, this is used to scale + modularity gain in Leiden refinement phase, to compute + the probability of joining a random leiden community. + do_expensive_check : bool_t If True, performs more extensive tests on the inputs to ensure validitity, at the expense of increased run time. 
@@ -117,10 +135,16 @@ def leiden(ResourceHandle resource_handle, cdef cugraph_error_code_t error_code cdef cugraph_error_t* error_ptr + cg_rng_state = CuGraphRandomState(resource_handle, random_state) + + cdef cugraph_rng_state_t* rng_state_ptr = cg_rng_state.rng_state_ptr + error_code = cugraph_leiden(c_resource_handle_ptr, + rng_state_ptr, c_graph_ptr, max_level, resolution, + theta, do_expensive_check, &result_ptr, &error_ptr)