diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py b/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py
index 57f7db3be01..e0d51bcf4cf 100644
--- a/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py
+++ b/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py
@@ -32,7 +32,13 @@ def __init__(
         self,
         total_number_of_nodes: int,
         edge_dir: str,
+        return_type: str = "dgl.Block",
     ):
+        if return_type not in ["dgl.Block", "cugraph_dgl.nn.SparseGraph"]:
+            raise ValueError(
+                "return_type must be either 'dgl.Block' or "
+                "'cugraph_dgl.nn.SparseGraph'"
+            )
         # TODO: Deprecate `total_number_of_nodes`
         # as it is no longer needed
         # in the next release
@@ -40,6 +46,7 @@ def __init__(
         self.edge_dir = edge_dir
         self._current_batch_fn = None
         self._input_files = None
+        self._return_type = return_type

     def __len__(self):
         return self.num_batches
@@ -55,7 +62,7 @@ def __getitem__(self, idx: int):
         if fn != self._current_batch_fn:
             df = _load_sampled_file(dataset_obj=self, fn=fn)
             self._current_batches = create_homogeneous_sampled_graphs_from_dataframe(
-                df, self.edge_dir
+                sampled_df=df, edge_dir=self.edge_dir, return_type=self._return_type
            )
         current_offset = idx - batch_offset
         return self._current_batches[current_offset]
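
The new return_type flag threads through from the dataset to the MFG builders changed
below. A minimal usage sketch (illustrative only, not part of the patch; assumes the
enclosing class is HomogenousBulkSamplerDataset and that sampled batches already exist
on disk):

    from cugraph_dgl.dataloading.dataset import HomogenousBulkSamplerDataset

    # Ask the dataset to yield cugraph_dgl.nn.SparseGraph MFGs instead of dgl.Block.
    dataset = HomogenousBulkSamplerDataset(
        total_number_of_nodes=6,  # slated for deprecation per the TODO above
        edge_dir="in",
        return_type="cugraph_dgl.nn.SparseGraph",
    )
    # Any other string raises ValueError at construction time.
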
""" + range_d = _get_source_destination_range(df) df, renumber_map, renumber_map_batch_indices = _get_renumber_map(df) batch_id_tensor = cast_to_tensor(df["batch_id"]) - batch_id_min = batch_id_tensor.min() - batch_id_max = batch_id_tensor.max() - batch_indices = torch.arange( - start=batch_id_min + 1, - end=batch_id_max + 1, - device=batch_id_tensor.device, - ) - # TODO: Fix below - # batch_indices = _get_id_tensor_boundaries(batch_id_tensor) - batch_indices = torch.searchsorted(batch_id_tensor, batch_indices).to("cpu") - split_d = {i: {} for i in range(batch_id_min, batch_id_max + 1)} + split_d, batch_indices = _create_split_dict(batch_id_tensor) + batch_split_indices = torch.searchsorted(batch_id_tensor, batch_indices).to("cpu") for column in df.columns: if column != "batch_id": t = cast_to_tensor(df[column]) - split_t = _split_tensor(t, batch_indices) + split_t = _split_tensor(t, batch_split_indices) for bid, batch_t in zip(split_d.keys(), split_t): split_d[bid][column] = batch_t @@ -91,35 +111,37 @@ def _get_tensor_d_from_sampled_df(df): split_d[bid]["map"] = batch_t del df result_tensor_d = {} + # Cache hop_split_d, hop_indices + hop_split_empty_d, hop_indices = None, None for batch_id, batch_d in split_d.items(): hop_id_tensor = batch_d["hop_id"] - hop_id_min = hop_id_tensor.min() - hop_id_max = hop_id_tensor.max() + if hop_split_empty_d is None: + hop_split_empty_d, hop_indices = _create_split_dict(hop_id_tensor) - hop_indices = torch.arange( - start=hop_id_min + 1, - end=hop_id_max + 1, - device=hop_id_tensor.device, - ) - # TODO: Fix below - # hop_indices = _get_id_tensor_boundaries(hop_id_tensor) - hop_indices = torch.searchsorted(hop_id_tensor, hop_indices).to("cpu") - hop_split_d = {i: {} for i in range(hop_id_min, hop_id_max + 1)} + hop_split_d = {k: {} for k in hop_split_empty_d.keys()} + hop_split_indices = torch.searchsorted(hop_id_tensor, hop_indices).to("cpu") for column, t in batch_d.items(): if column not in ["hop_id", "map"]: - split_t = _split_tensor(t, hop_indices) + split_t = _split_tensor(t, hop_split_indices) for hid, ht in zip(hop_split_d.keys(), split_t): hop_split_d[hid][column] = ht + for hid in hop_split_d.keys(): + hop_split_d[hid]["sources_range"] = range_d[(batch_id, hid)][ + "sources_range" + ] + hop_split_d[hid]["destinations_range"] = range_d[(batch_id, hid)][ + "destinations_range" + ] result_tensor_d[batch_id] = hop_split_d - if "map" in batch_d: - result_tensor_d[batch_id]["map"] = batch_d["map"] + result_tensor_d[batch_id]["map"] = batch_d["map"] return result_tensor_d def create_homogeneous_sampled_graphs_from_dataframe( sampled_df: cudf.DataFrame, edge_dir: str = "in", + return_type: str = "dgl.Block", ): """ This helper function creates DGL MFGS for @@ -136,11 +158,16 @@ def create_homogeneous_sampled_graphs_from_dataframe( - output_nodes: The output nodes for the batch. - graph_per_hop_ls: A list of DGL MFGS for each hop. 
""" + if return_type not in ["dgl.Block", "cugraph_dgl.nn.SparseGraph"]: + raise ValueError( + "return_type must be either dgl.Block or cugraph_dgl.nn.SparseGraph" + ) + result_tensor_d = _get_tensor_d_from_sampled_df(sampled_df) del sampled_df result_mfgs = [ _create_homogeneous_sampled_graphs_from_tensors_perhop( - tensors_batch_d, edge_dir + tensors_batch_d, edge_dir, return_type ) for tensors_batch_d in result_tensor_d.values() ] @@ -148,73 +175,111 @@ def create_homogeneous_sampled_graphs_from_dataframe( return result_mfgs -def _create_homogeneous_sampled_graphs_from_tensors_perhop(tensors_batch_d, edge_dir): +def _create_homogeneous_sampled_graphs_from_tensors_perhop( + tensors_batch_d, edge_dir, return_type +): """ This helper function creates sampled DGL MFGS for homogeneous graphs from tensors per hop for a single batch - Args: tensors_batch_d (dict): A dictionary of tensors, keyed by hop_id. edge_dir (str): Direction of edges from samples + metagraph (dgl.metagraph): The metagraph for the sampled graph + return_type (str): The type of graph to return Returns: tuple: A tuple of three elements: - input_nodes: The input nodes for the batch. - output_nodes: The output nodes for the batch. - - graph_per_hop_ls: A list of DGL MFGS for each hop. + - graph_per_hop_ls: A list of MFGS for each hop. """ if edge_dir not in ["in", "out"]: raise ValueError(f"Invalid edge_dir {edge_dir} provided") if edge_dir == "out": raise ValueError("Outwards edges not supported yet") graph_per_hop_ls = [] - seednodes = None + seednodes_range = None for hop_id, tensor_per_hop_d in tensors_batch_d.items(): if hop_id != "map": - block = _create_homogeneous_dgl_block_from_tensor_d( - tensor_per_hop_d, tensors_batch_d["map"], seednodes + if return_type == "dgl.Block": + mfg = _create_homogeneous_dgl_block_from_tensor_d( + tensor_d=tensor_per_hop_d, + renumber_map=tensors_batch_d["map"], + seednodes_range=seednodes_range, + ) + elif return_type == "cugraph_dgl.nn.SparseGraph": + mfg = _create_homogeneous_cugraph_dgl_nn_sparse_graph( + tensor_d=tensor_per_hop_d, seednodes_range=seednodes_range + ) + else: + raise ValueError(f"Invalid return_type {return_type} provided") + seednodes_range = max( + tensor_per_hop_d["sources_range"], + tensor_per_hop_d["destinations_range"], ) - seednodes = torch.concat( - [tensor_per_hop_d["sources"], tensor_per_hop_d["destinations"]] - ) - graph_per_hop_ls.append(block) + graph_per_hop_ls.append(mfg) # default DGL behavior if edge_dir == "in": graph_per_hop_ls.reverse() - - input_nodes = graph_per_hop_ls[0].srcdata[dgl.NID] - output_nodes = graph_per_hop_ls[-1].dstdata[dgl.NID] + if return_type == "dgl.Block": + input_nodes = graph_per_hop_ls[0].srcdata[dgl.NID] + output_nodes = graph_per_hop_ls[-1].dstdata[dgl.NID] + else: + map = tensors_batch_d["map"] + input_nodes = map[0 : graph_per_hop_ls[0].num_src_nodes()] + output_nodes = map[0 : graph_per_hop_ls[-1].num_dst_nodes()] return input_nodes, output_nodes, graph_per_hop_ls -def _create_homogeneous_dgl_block_from_tensor_d(tensor_d, renumber_map, seednodes=None): +def _create_homogeneous_dgl_block_from_tensor_d( + tensor_d, + renumber_map, + seednodes_range=None, +): rs = tensor_d["sources"] rd = tensor_d["destinations"] - - max_src_nodes = rs.max() - max_dst_nodes = rd.max() - if seednodes is not None: - # If we have isolated vertices + max_src_nodes = tensor_d["sources_range"] + max_dst_nodes = tensor_d["destinations_range"] + if seednodes_range is not None: + # If we have vertices without outgoing edges, then # sources can be 
@@ -148,73 +175,111 @@ def create_homogeneous_sampled_graphs_from_dataframe(
     return result_mfgs


-def _create_homogeneous_sampled_graphs_from_tensors_perhop(tensors_batch_d, edge_dir):
+def _create_homogeneous_sampled_graphs_from_tensors_perhop(
+    tensors_batch_d, edge_dir, return_type
+):
     """
     This helper function creates sampled DGL MFGS for
     homogeneous graphs from tensors per hop for a single
     batch
-
     Args:
         tensors_batch_d (dict): A dictionary of tensors, keyed by hop_id.
         edge_dir (str): Direction of edges from samples
+        return_type (str): The type of graph to return
     Returns:
         tuple: A tuple of three elements:
             - input_nodes: The input nodes for the batch.
             - output_nodes: The output nodes for the batch.
-            - graph_per_hop_ls: A list of DGL MFGS for each hop.
+            - graph_per_hop_ls: A list of MFGS for each hop.
     """
     if edge_dir not in ["in", "out"]:
         raise ValueError(f"Invalid edge_dir {edge_dir} provided")
     if edge_dir == "out":
         raise ValueError("Outwards edges not supported yet")
     graph_per_hop_ls = []
-    seednodes = None
+    seednodes_range = None
     for hop_id, tensor_per_hop_d in tensors_batch_d.items():
         if hop_id != "map":
-            block = _create_homogeneous_dgl_block_from_tensor_d(
-                tensor_per_hop_d, tensors_batch_d["map"], seednodes
+            if return_type == "dgl.Block":
+                mfg = _create_homogeneous_dgl_block_from_tensor_d(
+                    tensor_d=tensor_per_hop_d,
+                    renumber_map=tensors_batch_d["map"],
+                    seednodes_range=seednodes_range,
+                )
+            elif return_type == "cugraph_dgl.nn.SparseGraph":
+                mfg = _create_homogeneous_cugraph_dgl_nn_sparse_graph(
+                    tensor_d=tensor_per_hop_d, seednodes_range=seednodes_range
+                )
+            else:
+                raise ValueError(f"Invalid return_type {return_type} provided")
+            seednodes_range = max(
+                tensor_per_hop_d["sources_range"],
+                tensor_per_hop_d["destinations_range"],
             )
-            seednodes = torch.concat(
-                [tensor_per_hop_d["sources"], tensor_per_hop_d["destinations"]]
-            )
-            graph_per_hop_ls.append(block)
+            graph_per_hop_ls.append(mfg)

     # default DGL behavior
     if edge_dir == "in":
         graph_per_hop_ls.reverse()
-
-    input_nodes = graph_per_hop_ls[0].srcdata[dgl.NID]
-    output_nodes = graph_per_hop_ls[-1].dstdata[dgl.NID]
+    if return_type == "dgl.Block":
+        input_nodes = graph_per_hop_ls[0].srcdata[dgl.NID]
+        output_nodes = graph_per_hop_ls[-1].dstdata[dgl.NID]
+    else:
+        map = tensors_batch_d["map"]
+        input_nodes = map[0 : graph_per_hop_ls[0].num_src_nodes()]
+        output_nodes = map[0 : graph_per_hop_ls[-1].num_dst_nodes()]
     return input_nodes, output_nodes, graph_per_hop_ls


-def _create_homogeneous_dgl_block_from_tensor_d(tensor_d, renumber_map, seednodes=None):
+def _create_homogeneous_dgl_block_from_tensor_d(
+    tensor_d,
+    renumber_map,
+    seednodes_range=None,
+):
     rs = tensor_d["sources"]
     rd = tensor_d["destinations"]
-
-    max_src_nodes = rs.max()
-    max_dst_nodes = rd.max()
-    if seednodes is not None:
-        # If we have isolated vertices
+    max_src_nodes = tensor_d["sources_range"]
+    max_dst_nodes = tensor_d["destinations_range"]
+    if seednodes_range is not None:
+        # If we have vertices without outgoing edges, then
+        # sources can be missing from seednodes
         # so we add them
         # to ensure all the blocks are
-        # linedup correctly
-        max_dst_nodes = max(max_dst_nodes, seednodes.max())
+        # lined up correctly
+        max_dst_nodes = max(max_dst_nodes, seednodes_range)
     data_dict = {("_N", "_E", "_N"): (rs, rd)}
-    num_src_nodes = {"_N": max_src_nodes.item() + 1}
-    num_dst_nodes = {"_N": max_dst_nodes.item() + 1}
+    num_src_nodes = {"_N": max_src_nodes + 1}
+    num_dst_nodes = {"_N": max_dst_nodes + 1}
+
     block = dgl.create_block(
         data_dict=data_dict, num_src_nodes=num_src_nodes, num_dst_nodes=num_dst_nodes
     )
     if "edge_id" in tensor_d:
         block.edata[dgl.EID] = tensor_d["edge_id"]
-    block.srcdata[dgl.NID] = renumber_map[block.srcnodes()]
-    block.dstdata[dgl.NID] = renumber_map[block.dstnodes()]
+    # Note: the renumber-map lookups below add runtime overhead
+    block.srcdata[dgl.NID] = renumber_map[0 : max_src_nodes + 1]
+    block.dstdata[dgl.NID] = renumber_map[0 : max_dst_nodes + 1]
     return block


+def _create_homogeneous_cugraph_dgl_nn_sparse_graph(tensor_d, seednodes_range):
+    max_src_nodes = tensor_d["sources_range"]
+    max_dst_nodes = tensor_d["destinations_range"]
+    if seednodes_range is not None:
+        max_dst_nodes = max(max_dst_nodes, seednodes_range)
+    size = (max_src_nodes + 1, max_dst_nodes + 1)
+    sparse_graph = cugraph_dgl.nn.SparseGraph(
+        size=size,
+        src_ids=tensor_d["sources"],
+        dst_ids=tensor_d["destinations"],
+        formats=["csc"],
+        reduce_memory=True,
+    )
+    return sparse_graph
+
+
 def create_heterogeneous_sampled_graphs_from_dataframe(
     sampled_df: cudf.DataFrame,
     num_nodes_dict: Dict[str, int],
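
The SparseGraph constructor call above can also be exercised directly. A standalone
sketch mirroring the arguments used in _create_homogeneous_cugraph_dgl_nn_sparse_graph
(requires a CUDA device; the size/format semantics follow the usage in this patch):

    import torch
    import cugraph_dgl

    # A tiny MFG with 2 source nodes and 3 destination nodes.
    sg = cugraph_dgl.nn.SparseGraph(
        size=(2, 3),
        src_ids=torch.tensor([0, 0, 1, 1], dtype=torch.int64, device="cuda"),
        dst_ids=torch.tensor([0, 1, 1, 2], dtype=torch.int64, device="cuda"),
        formats=["csc"],      # only the CSC representation is materialized
        reduce_memory=True,   # as passed by the new helper above
    )
    offsets, indices = sg.csc()
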
diff --git a/python/cugraph-dgl/tests/test_dataset.py b/python/cugraph-dgl/tests/test_dataset.py
index a1da77721a3..69d50261e55 100644
--- a/python/cugraph-dgl/tests/test_dataset.py
+++ b/python/cugraph-dgl/tests/test_dataset.py
@@ -47,20 +47,18 @@ def create_dgl_mfgs(g, seed_nodes, fanout):
     return sampler.sample_blocks(g, seed_nodes)


-def create_cugraph_dgl_homogenous_mfgs(g, seed_nodes, fanout):
+def create_cugraph_dgl_homogenous_mfgs(dgl_blocks, return_type):
     df_ls = []
     unique_vertices_ls = []
-    for hop_id, fanout in enumerate(reversed(fanout)):
-        frontier = g.sample_neighbors(seed_nodes, fanout)
-        # Set include_dst_in_src to match cugraph behavior
-        block = dgl.to_block(frontier, seed_nodes, include_dst_in_src=False)
-        block.edata[dgl.EID] = frontier.edata[dgl.EID]
-        seed_nodes = block.srcdata[dgl.NID]
+    for hop_id, block in enumerate(reversed(dgl_blocks)):
         block = block.to("cpu")
         src, dst, eid = block.edges("all")
         eid = block.edata[dgl.EID][eid]
+
+        og_src = block.srcdata[dgl.NID][src]
+        og_dst = block.dstdata[dgl.NID][dst]
         unique_vertices = pd.concat(
-            [pd.Series(dst.numpy()), pd.Series(src.numpy())]
+            [pd.Series(og_dst.numpy()), pd.Series(og_src.numpy())]
         ).drop_duplicates(keep="first")
         unique_vertices_ls.append(unique_vertices)
         df = cudf.DataFrame(
@@ -84,23 +82,24 @@ def create_cugraph_dgl_homogenous_mfgs(g, seed_nodes, fanout):
     # Have to reindex cause map_ser can be of larger length than df
     df = df.reindex(df.index.union(map_ser.index))
     df["map"] = map_ser
-    return create_homogeneous_sampled_graphs_from_dataframe(df)[0]
+    return create_homogeneous_sampled_graphs_from_dataframe(
+        df, return_type=return_type
+    )[0]


+@pytest.mark.parametrize("return_type", ["dgl.Block", "cugraph_dgl.nn.SparseGraph"])
 @pytest.mark.parametrize("seed_node", [3, 4, 5])
-def test_homogeneous_sampled_graphs_from_dataframe(seed_node):
+def test_homogeneous_sampled_graphs_from_dataframe(return_type, seed_node):
     g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 5]))
     fanout = [1, 1, 1]
     seed_node = torch.as_tensor([seed_node])

-    dgl_seed_nodes, dgl_output_nodes, dgl_mfgs = create_cugraph_dgl_homogenous_mfgs(
-        g, seed_node, fanout
-    )
+    dgl_seed_nodes, dgl_output_nodes, dgl_mfgs = create_dgl_mfgs(g, seed_node, fanout)
     (
         cugraph_seed_nodes,
         cugraph_output_nodes,
         cugraph_mfgs,
-    ) = create_cugraph_dgl_homogenous_mfgs(g, seed_node, fanout)
+    ) = create_cugraph_dgl_homogenous_mfgs(dgl_mfgs, return_type=return_type)

     np.testing.assert_equal(
         cugraph_seed_nodes.cpu().numpy().copy().sort(),
@@ -112,7 +111,18 @@ def test_homogeneous_sampled_graphs_from_dataframe(return_type, seed_node):
         cugraph_output_nodes.cpu().numpy().copy().sort(),
     )

-    for dgl_block, cugraph_dgl_block in zip(dgl_mfgs, cugraph_mfgs):
-        dgl_df = get_edge_df_from_homogenous_block(dgl_block)
-        cugraph_dgl_df = get_edge_df_from_homogenous_block(cugraph_dgl_block)
-        pd.testing.assert_frame_equal(dgl_df, cugraph_dgl_df)
+    if return_type == "dgl.Block":
+        for dgl_block, cugraph_dgl_block in zip(dgl_mfgs, cugraph_mfgs):
+            dgl_df = get_edge_df_from_homogenous_block(dgl_block)
+            cugraph_dgl_df = get_edge_df_from_homogenous_block(cugraph_dgl_block)
+            pd.testing.assert_frame_equal(dgl_df, cugraph_dgl_df)
+    else:
+        for dgl_block, cugraph_dgl_graph in zip(dgl_mfgs, cugraph_mfgs):
+            # Cannot verify edge ids as they are not
+            # preserved in cugraph_dgl.nn.SparseGraph
+            assert dgl_block.num_src_nodes() == cugraph_dgl_graph.num_src_nodes()
+            assert dgl_block.num_dst_nodes() == cugraph_dgl_graph.num_dst_nodes()
+            dgl_offsets, dgl_indices, _ = dgl_block.adj_tensors("csc")
+            cugraph_offsets, cugraph_indices = cugraph_dgl_graph.csc()
+            assert torch.equal(dgl_offsets.to("cpu"), cugraph_offsets.to("cpu"))
+            assert torch.equal(dgl_indices.to("cpu"), cugraph_indices.to("cpu"))
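
The CSC assertions at the end of the test compare compressed representations directly.
What that equality means, worked through on a toy block (illustrative arithmetic only,
not test code):

    import torch

    dst_ids = torch.tensor([0, 1, 1, 2])  # edges sorted by destination
    num_dst_nodes = 3
    # CSC offsets are cumulative per-destination edge counts; indices are the
    # source ids laid out in the same destination-grouped order.
    offsets = torch.zeros(num_dst_nodes + 1, dtype=torch.int64)
    offsets[1:] = torch.cumsum(torch.bincount(dst_ids, minlength=num_dst_nodes), 0)
    # offsets == tensor([0, 1, 3, 4])
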
diff --git a/python/cugraph-dgl/tests/test_utils.py b/python/cugraph-dgl/tests/test_utils.py
index fd75b1537b5..740db59ce7f 100644
--- a/python/cugraph-dgl/tests/test_utils.py
+++ b/python/cugraph-dgl/tests/test_utils.py
@@ -20,11 +20,14 @@
     _split_tensor,
     _get_tensor_d_from_sampled_df,
     create_homogeneous_sampled_graphs_from_dataframe,
+    _get_source_destination_range,
+    _create_homogeneous_cugraph_dgl_nn_sparse_graph,
 )
 from cugraph.utilities.utils import import_optional

 dgl = import_optional("dgl")
 torch = import_optional("torch")
+cugraph_dgl = import_optional("cugraph_dgl")


 def test_casting_empty_array():
@@ -140,3 +143,36 @@ def test_create_homogeneous_sampled_graphs_from_dataframe():
         de, dd = d_block.edges()
         assert torch.equal(ce, de)
         assert torch.equal(cd, dd)
+
+
+def test_get_source_destination_range():
+    df = get_dummy_sampled_df()
+    output_d = _get_source_destination_range(df)
+
+    expected_output = {
+        (0, 0): {"sources_range": 0, "destinations_range": 1},
+        (0, 1): {"sources_range": 1, "destinations_range": 2},
+        (1, 0): {"sources_range": 0, "destinations_range": 1},
+        (1, 1): {"sources_range": 1, "destinations_range": 2},
+        (2, 0): {"sources_range": 0, "destinations_range": 2},
+        (2, 1): {"sources_range": 2, "destinations_range": 1},
+    }
+
+    assert output_d == expected_output
+
+
+def test__create_homogeneous_cugraph_dgl_nn_sparse_graph():
+    tensor_d = {
+        "sources_range": 1,
+        "destinations_range": 2,
+        "sources": torch.as_tensor([0, 0, 1, 1], dtype=torch.int64, device="cuda"),
+        "destinations": torch.as_tensor([0, 0, 1, 2], dtype=torch.int64, device="cuda"),
+    }
+
+    seednodes_range = 10
+    sparse_graph = _create_homogeneous_cugraph_dgl_nn_sparse_graph(
+        tensor_d, seednodes_range
+    )
+    assert sparse_graph.num_src_nodes() == 2
+    assert sparse_graph.num_dst_nodes() == seednodes_range + 1
+    assert isinstance(sparse_graph, cugraph_dgl.nn.SparseGraph)
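
The expected sizes in the final test follow from the range bookkeeping introduced in
sampling_helpers.py (worked arithmetic, not additional test code):

    # num_src_nodes = sources_range + 1                        = 1 + 1       = 2
    # num_dst_nodes = max(destinations_range, seednodes_range) + 1
    #               = max(2, 10) + 1                                         = 11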