Skip to content

Commit

Permalink
Re-enable HITS in the python API using the new primitive-based implem…
Browse files Browse the repository at this point in the history
…entation (#1941)

Followup PR to #1930

This PR re-enables HITS in the python API using the new C++ primitive-based implementation. This also refactors the tests to use the benchmark fixture plugin, and adds an additional dataset to read for more diverse comparison to Nx.

![image](https://user-images.githubusercontent.com/3039903/141236018-16557063-9d2a-4fd7-b8c1-789f78958ea7.png)

Authors:
  - Rick Ratzel (https://github.com/rlratzel)

Approvers:
  - Chuck Hastings (https://github.com/ChuckHastings)
  - Brad Rees (https://github.com/BradReesWork)
  - Joseph Nke (https://github.com/jnke2016)

URL: #1941
  • Loading branch information
rlratzel authored Nov 11, 2021
1 parent 0678065 commit 2770c87
Show file tree
Hide file tree
Showing 11 changed files with 307 additions and 140 deletions.
11 changes: 11 additions & 0 deletions cpp/include/cugraph/utilities/cython.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,17 @@ void call_wcc(raft::handle_t const& handle,
graph_container_t const& graph_container,
vertex_t* components);

// Wrapper for calling HITS through a graph container
template <typename vertex_t, typename weight_t>
void call_hits(raft::handle_t const& handle,
graph_container_t const& graph_container,
weight_t* hubs,
weight_t* authorities,
size_t max_iter,
weight_t tolerance,
const weight_t* starting_value,
bool normalized);

// Wrapper for calling graph generator
template <typename vertex_t>
std::unique_ptr<graph_generator_t> call_generate_rmat_edgelist(raft::handle_t const& handle,
Expand Down
114 changes: 114 additions & 0 deletions cpp/src/utilities/cython.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,84 @@ void call_wcc(raft::handle_t const& handle,
}
}

// wrapper for HITS:
//
template <typename vertex_t, typename weight_t>
void call_hits(raft::handle_t const& handle,
graph_container_t const& graph_container,
weight_t* hubs,
weight_t* authorities,
size_t max_iter,
weight_t tolerance,
const weight_t* starting_value,
bool normalized)
{
constexpr bool has_initial_hubs_guess{false};
constexpr bool normalize{true};
constexpr bool do_expensive_check{false};
constexpr bool transposed{true};

// FIXME: most of these branches are not currently executed: MG support is not
// yet in the python API, and only int32_t edge types are being used. Consider
// removing these until actually needed.

if (graph_container.is_multi_gpu) {
constexpr bool multi_gpu{true};
if (graph_container.edgeType == numberTypeEnum::int32Type) {
auto graph = detail::create_graph<int32_t, int32_t, weight_t, transposed, multi_gpu>(
handle, graph_container);
cugraph::hits(handle,
graph->view(),
reinterpret_cast<weight_t*>(hubs),
reinterpret_cast<weight_t*>(authorities),
tolerance,
max_iter,
has_initial_hubs_guess,
normalize,
do_expensive_check);
} else if (graph_container.edgeType == numberTypeEnum::int64Type) {
auto graph = detail::create_graph<vertex_t, int64_t, weight_t, transposed, multi_gpu>(
handle, graph_container);
cugraph::hits(handle,
graph->view(),
reinterpret_cast<weight_t*>(hubs),
reinterpret_cast<weight_t*>(authorities),
tolerance,
max_iter,
has_initial_hubs_guess,
normalize,
do_expensive_check);
}
} else {
constexpr bool multi_gpu{false};
if (graph_container.edgeType == numberTypeEnum::int32Type) {
auto graph = detail::create_graph<int32_t, int32_t, weight_t, transposed, multi_gpu>(
handle, graph_container);
cugraph::hits(handle,
graph->view(),
reinterpret_cast<weight_t*>(hubs),
reinterpret_cast<weight_t*>(authorities),
tolerance,
max_iter,
has_initial_hubs_guess,
normalize,
do_expensive_check);
} else if (graph_container.edgeType == numberTypeEnum::int64Type) {
auto graph = detail::create_graph<vertex_t, int64_t, weight_t, transposed, multi_gpu>(
handle, graph_container);
cugraph::hits(handle,
graph->view(),
reinterpret_cast<weight_t*>(hubs),
reinterpret_cast<weight_t*>(authorities),
tolerance,
max_iter,
has_initial_hubs_guess,
normalize,
do_expensive_check);
}
}
}

// wrapper for shuffling:
//
template <typename vertex_t, typename edge_t, typename weight_t>
Expand Down Expand Up @@ -1509,6 +1587,42 @@ template void call_wcc<int64_t, double>(raft::handle_t const& handle,
graph_container_t const& graph_container,
int64_t* components);

template void call_hits<int32_t, float>(raft::handle_t const& handle,
graph_container_t const& graph_container,
float* hubs,
float* authorities,
size_t max_iter,
float tolerance,
const float* starting_value,
bool normalized);

template void call_hits<int32_t, double>(raft::handle_t const& handle,
graph_container_t const& graph_container,
double* hubs,
double* authorities,
size_t max_iter,
double tolerance,
const double* starting_value,
bool normalized);

template void call_hits<int64_t, float>(raft::handle_t const& handle,
graph_container_t const& graph_container,
float* hubs,
float* authorities,
size_t max_iter,
float tolerance,
const float* starting_value,
bool normalized);

template void call_hits<int64_t, double>(raft::handle_t const& handle,
graph_container_t const& graph_container,
double* hubs,
double* authorities,
size_t max_iter,
double tolerance,
const double* starting_value,
bool normalized);

template std::unique_ptr<major_minor_weights_t<int32_t, int32_t, float>> call_shuffle(
raft::handle_t const& handle,
int32_t* edgelist_major_vertices,
Expand Down
1 change: 0 additions & 1 deletion notebooks/link_analysis/HITS.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"metadata": {},
"source": [
"# HITS\n",
"# Skip notebook test\n",
"\n",
"In this notebook, we will use both NetworkX and cuGraph to compute HITS. \n",
"The NetworkX and cuGraph processes will be interleaved so that each step can be compared.\n",
Expand Down
12 changes: 6 additions & 6 deletions python/cugraph/cugraph/centrality/betweenness_centrality.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def betweenness_centrality(
Parameters
----------
G : cuGraph.Graph or networkx.Graph
The graph can be either directed (DiGraph) or undirected (Graph).
The graph can be either directed (Graph(directed=True)) or undirected.
Weights in the graph are ignored, the current implementation uses
BFS traversals. Use weight parameter if weights need to be considered
(currently not supported)
Expand All @@ -65,8 +65,8 @@ def betweenness_centrality(
normalized : bool, optional
Default is True.
If true, the betweenness values are normalized by
__2 / ((n - 1) * (n - 2))__ for Graphs (undirected), and
__1 / ((n - 1) * (n - 2))__ for DiGraphs (directed graphs)
__2 / ((n - 1) * (n - 2))__ for undirected Graphs, and
__1 / ((n - 1) * (n - 2))__ for directed Graphs
where n is the number of nodes in G.
Normalization will ensure that values are in [0, 1],
this normalization scales for the highest possible value where one
Expand Down Expand Up @@ -170,7 +170,7 @@ def edge_betweenness_centrality(
Parameters
----------
G : cuGraph.Graph or networkx.Graph
The graph can be either directed (DiGraph) or undirected (Graph).
The graph can be either directed (Graph(directed=True)) or undirected.
Weights in the graph are ignored, the current implementation uses
BFS traversals. Use weight parameter if weights need to be considered
(currently not supported)
Expand All @@ -186,8 +186,8 @@ def edge_betweenness_centrality(
normalized : bool, optional
Default is True.
If true, the betweenness values are normalized by
2 / (n * (n - 1)) for Graphs (undirected), and
1 / (n * (n - 1)) for DiGraphs (directed graphs)
2 / (n * (n - 1)) for undirected Graphs, and
1 / (n * (n - 1)) for directed Graphs
where n is the number of nodes in G.
Normalization will ensure that values are in [0, 1],
this normalization scales for the highest possible value where one
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# cython: language_level = 3

from cugraph.centrality.betweenness_centrality cimport betweenness_centrality as c_betweenness_centrality
from cugraph.structure.graph_classes import DiGraph
from cugraph.structure.graph_primtypes cimport *
from libc.stdint cimport uintptr_t
from libcpp cimport bool
Expand Down Expand Up @@ -177,8 +176,7 @@ def batch_betweenness_centrality(input_graph, normalized, endpoints,
comms = Comms.get_comms()
replicated_adjlists = input_graph.batch_adjlists
work_futures = [client.submit(run_mg_work,
(data, type(input_graph)
is DiGraph),
(data, input_graph.is_directed()),
normalized,
endpoints,
weights,
Expand All @@ -197,7 +195,7 @@ def sg_betweenness_centrality(input_graph, normalized, endpoints, weights,
handle = Comms.get_default_handle()
adjlist = input_graph.adjlist
input_data = ((adjlist.offsets, adjlist.indices, adjlist.weights),
type(input_graph) is DiGraph)
input_graph.is_directed())
df = run_internal_work(handle, input_data, normalized, endpoints, weights,
vertices, result_dtype)
return df
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from cugraph.centrality.betweenness_centrality cimport edge_betweenness_centrality as c_edge_betweenness_centrality
from cugraph.structure import graph_primtypes_wrapper
from cugraph.structure.graph_classes import DiGraph, Graph
from cugraph.structure.graph_primtypes cimport *
from libc.stdint cimport uintptr_t
from libcpp cimport bool
Expand Down Expand Up @@ -166,8 +165,7 @@ def batch_edge_betweenness_centrality(input_graph,
comms = Comms.get_comms()
replicated_adjlists = input_graph.batch_adjlists
work_futures = [client.submit(run_mg_work,
(data, type(input_graph)
is DiGraph),
(data, input_graph.is_directed()),
normalized,
weights,
vertices,
Expand All @@ -188,7 +186,7 @@ def sg_edge_betweenness_centrality(input_graph, normalized, weights,
handle = Comms.get_default_handle()
adjlist = input_graph.adjlist
input_data = ((adjlist.offsets, adjlist.indices, adjlist.weights),
type(input_graph) is DiGraph)
input_graph.is_directed())
df = run_internal_work(handle, input_data, normalized, weights,
vertices, result_dtype)
return df
Expand Down
25 changes: 11 additions & 14 deletions python/cugraph/cugraph/link_analysis/hits.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3

from cugraph.structure.graph_primtypes cimport *
from libcpp cimport bool

from cugraph.structure.graph_utilities cimport graph_container_t
from cugraph.raft.common.handle cimport handle_t

cdef extern from "cugraph/algorithms.hpp" namespace "cugraph::gunrock":

cdef void hits[VT,ET,WT](
const GraphCSRView[VT,ET,WT] &graph,
cdef extern from "cugraph/utilities/cython.hpp" namespace "cugraph::cython":
cdef void call_hits[vertex_t,weight_t](
const handle_t &handle,
const graph_container_t &g,
weight_t *hubs,
weight_t *authorities,
int max_iter,
WT tolerance,
const WT *starting_value,
bool normalized,
WT *hubs,
WT *authorities) except +
weight_t tolerance,
const weight_t *starting_value,
bool normalized) except +
14 changes: 5 additions & 9 deletions python/cugraph/cugraph/link_analysis/hits.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# from cugraph.link_analysis import hits_wrapper

# from cugraph.utilities import (ensure_cugraph_obj_for_nx,
# df_score_to_dictionary,
# )
from cugraph.link_analysis import hits_wrapper
from cugraph.utilities import (ensure_cugraph_obj_for_nx,
df_score_to_dictionary,
)


def hits(G, max_iter=100, tol=1.0e-5, nstart=None, normalized=True):
Expand Down Expand Up @@ -76,10 +75,9 @@ def hits(G, max_iter=100, tol=1.0e-5, nstart=None, normalized=True):
>>> hits = cugraph.hits(G, max_iter = 50)
"""

"""
G, isNx = ensure_cugraph_obj_for_nx(G)

df = hits_wrapper.hits(G, max_iter, tol) # noqa: F821
df = hits_wrapper.hits(G, max_iter, tol)

if G.renumbered:
df = G.unrenumber(df, "vertex")
Expand All @@ -91,5 +89,3 @@ def hits(G, max_iter=100, tol=1.0e-5, nstart=None, normalized=True):
df = (d1, d2)

return df
"""
raise NotImplementedError("Temporarily disabled. New version in 21.12")
Loading

0 comments on commit 2770c87

Please sign in to comment.