Skip to content

Commit

Permalink
[BUG] Critical: Fix cuGraph-PyG Edge Index Renumbering for Single-Edg…
Browse files Browse the repository at this point in the history
…e Graphs (#3605)

The changes for multi-edge graphs to accommodate the new hop id behavior were not properly applied to single-edge graphs.  This PR resolves that.

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: #3605
  • Loading branch information
alexbarghi-nv authored May 26, 2023
1 parent 483d36b commit 24b6f3a
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ def _get_vertex_groups_from_sample(

vtypes = cudf.Series(self.__vertex_type_offsets["type"])
if len(vtypes) == 1:
noi_index[vtypes[0]] = nodes_of_interest
noi_index[vtypes.iloc[0]] = nodes_of_interest
else:
noi_type_indices = torch.searchsorted(
torch.as_tensor(self.__vertex_type_offsets["stop"], device="cuda"),
Expand Down Expand Up @@ -788,17 +788,26 @@ def _get_renumbered_edge_groups_from_sample(
t_pyg_type = list(self.__edge_types_to_attrs.values())[0].edge_type
src_type, _, dst_type = t_pyg_type

sources = torch.as_tensor(sampling_results.sources.values, device="cuda")
src_id_table = noi_index[src_type]
src = torch.searchsorted(src_id_table, sources)
row_dict[t_pyg_type] = src
dst_id_table = noi_index[dst_type]
dst_id_map = (
cudf.Series(cupy.asarray(dst_id_table), name="dst")
.reset_index()
.rename(columns={"index": "new_id"})
.set_index("dst")
)
dst = dst_id_map["new_id"].loc[sampling_results.destinations]
col_dict[t_pyg_type] = torch.as_tensor(dst.values, device="cuda")

destinations = torch.as_tensor(
sampling_results.destinations.values, device="cuda"
src_id_table = noi_index[src_type]
src_id_map = (
cudf.Series(cupy.asarray(src_id_table), name="src")
.reset_index()
.rename(columns={"index": "new_id"})
.set_index("src")
)
dst_id_table = noi_index[dst_type]
dst = torch.searchsorted(dst_id_table, destinations)
col_dict[t_pyg_type] = dst
src = src_id_map["new_id"].loc[sampling_results.sources]
row_dict[t_pyg_type] = torch.as_tensor(src.values, device="cuda")

else:
# This will retrieve the single string representation.
# It needs to be converted to a tuple in the for loop below.
Expand Down

0 comments on commit 24b6f3a

Please sign in to comment.