Skip to content

Commit

Permalink
[REVIEW] Cugraph dgl block improvements (#3810)
Browse files Browse the repository at this point in the history
This PR fixes: #3784 and speeds up MFG creation by `3.5x` . 

Todo:
- [x] Add tests


Benchmarked on 6_462_743_488 edges with a batch size of `128` on a 1 V100:

Before PR Times:
```
1min 17s
```

After PR Times:
```
22 s
```

See link: https://gist.github.com/VibhuJawa/4852203f2e96de09d84d698af945682d


**Profiling:**

After PR: #3810

<img width="1252" alt="image" src="https://github.com/rapidsai/cugraph/assets/4837571/4cbe5153-4251-4195-9471-c60d11cdf7e9">


<img width="1252" alt="image" src="https://github.com/rapidsai/cugraph/assets/4837571/ad019f47-6ccf-45b2-b866-9a2f4f16bc9b">


Profile of splitting df into tensors : 

<img width="781" alt="image" src="https://github.com/rapidsai/cugraph/assets/4837571/82c401d6-1fca-44da-871f-ebb163a464ba">

Authors:
  - Vibhu Jawa (https://github.com/VibhuJawa)

Approvers:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Tingyu Wang (https://github.com/tingyu66)

URL: #3810
  • Loading branch information
VibhuJawa authored Aug 22, 2023
1 parent fa99e34 commit f0d16c1
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 74 deletions.
9 changes: 8 additions & 1 deletion python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,21 @@ def __init__(
self,
total_number_of_nodes: int,
edge_dir: str,
return_type: str = "dgl.Block",
):
if return_type not in ["dgl.Block", "cugraph_dgl.nn.SparseGraph"]:
raise ValueError(
"return_type must be either 'dgl.Block' or \
'cugraph_dgl.nn.SparseGraph' "
)
# TODO: Deprecate `total_number_of_nodes`
# as it is no longer needed
# in the next release
self.total_number_of_nodes = total_number_of_nodes
self.edge_dir = edge_dir
self._current_batch_fn = None
self._input_files = None
self._return_type = return_type

def __len__(self):
return self.num_batches
Expand All @@ -55,7 +62,7 @@ def __getitem__(self, idx: int):
if fn != self._current_batch_fn:
df = _load_sampled_file(dataset_obj=self, fn=fn)
self._current_batches = create_homogeneous_sampled_graphs_from_dataframe(
df, self.edge_dir
sampled_df=df, edge_dir=self.edge_dir, return_type=self._return_type
)
current_offset = idx - batch_offset
return self._current_batches[current_offset]
Expand Down
175 changes: 120 additions & 55 deletions python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

dgl = import_optional("dgl")
torch = import_optional("torch")
cugraph_dgl = import_optional("cugraph_dgl")


def cast_to_tensor(ser: cudf.Series):
Expand All @@ -40,6 +41,30 @@ def _split_tensor(t, split_indices):
return torch.tensor_split(t, split_indices)


def _get_source_destination_range(sampled_df):
o = sampled_df.groupby(["batch_id", "hop_id"], as_index=True).agg(
{"sources": "max", "destinations": "max"}
)
o.rename(
columns={"sources": "sources_range", "destinations": "destinations_range"},
inplace=True,
)
d = o.to_dict(orient="index")
return d


def _create_split_dict(tensor):
min_value = tensor.min()
max_value = tensor.max()
indices = torch.arange(
start=min_value + 1,
end=max_value + 1,
device=tensor.device,
)
split_dict = {i: {} for i in range(min_value, max_value + 1)}
return split_dict, indices


def _get_renumber_map(df):
map = df["map"]
df.drop(columns=["map"], inplace=True)
Expand All @@ -49,9 +74,12 @@ def _get_renumber_map(df):
renumber_map_batch_indices = map[1 : map_starting_offset - 1].reset_index(drop=True)
renumber_map_batch_indices = renumber_map_batch_indices - map_starting_offset

# Drop all rows with NaN values
df.dropna(axis=0, how="all", inplace=True)
df.reset_index(drop=True, inplace=True)
map_end_offset = map_starting_offset + len(renumber_map)
# We only need to drop rows if the length of dataframe is determined by the map
# that is if map_length > sampled edges length
if map_end_offset == len(df):
df.dropna(axis=0, how="all", inplace=True)
df.reset_index(drop=True, inplace=True)

return df, cast_to_tensor(renumber_map), cast_to_tensor(renumber_map_batch_indices)

Expand All @@ -65,24 +93,16 @@ def _get_tensor_d_from_sampled_df(df):
Returns:
dict: A dictionary of tensors, keyed by batch_id and hop_id.
"""
range_d = _get_source_destination_range(df)
df, renumber_map, renumber_map_batch_indices = _get_renumber_map(df)
batch_id_tensor = cast_to_tensor(df["batch_id"])
batch_id_min = batch_id_tensor.min()
batch_id_max = batch_id_tensor.max()
batch_indices = torch.arange(
start=batch_id_min + 1,
end=batch_id_max + 1,
device=batch_id_tensor.device,
)
# TODO: Fix below
# batch_indices = _get_id_tensor_boundaries(batch_id_tensor)
batch_indices = torch.searchsorted(batch_id_tensor, batch_indices).to("cpu")
split_d = {i: {} for i in range(batch_id_min, batch_id_max + 1)}
split_d, batch_indices = _create_split_dict(batch_id_tensor)
batch_split_indices = torch.searchsorted(batch_id_tensor, batch_indices).to("cpu")

for column in df.columns:
if column != "batch_id":
t = cast_to_tensor(df[column])
split_t = _split_tensor(t, batch_indices)
split_t = _split_tensor(t, batch_split_indices)
for bid, batch_t in zip(split_d.keys(), split_t):
split_d[bid][column] = batch_t

Expand All @@ -91,35 +111,37 @@ def _get_tensor_d_from_sampled_df(df):
split_d[bid]["map"] = batch_t
del df
result_tensor_d = {}
# Cache hop_split_d, hop_indices
hop_split_empty_d, hop_indices = None, None
for batch_id, batch_d in split_d.items():
hop_id_tensor = batch_d["hop_id"]
hop_id_min = hop_id_tensor.min()
hop_id_max = hop_id_tensor.max()
if hop_split_empty_d is None:
hop_split_empty_d, hop_indices = _create_split_dict(hop_id_tensor)

hop_indices = torch.arange(
start=hop_id_min + 1,
end=hop_id_max + 1,
device=hop_id_tensor.device,
)
# TODO: Fix below
# hop_indices = _get_id_tensor_boundaries(hop_id_tensor)
hop_indices = torch.searchsorted(hop_id_tensor, hop_indices).to("cpu")
hop_split_d = {i: {} for i in range(hop_id_min, hop_id_max + 1)}
hop_split_d = {k: {} for k in hop_split_empty_d.keys()}
hop_split_indices = torch.searchsorted(hop_id_tensor, hop_indices).to("cpu")
for column, t in batch_d.items():
if column not in ["hop_id", "map"]:
split_t = _split_tensor(t, hop_indices)
split_t = _split_tensor(t, hop_split_indices)
for hid, ht in zip(hop_split_d.keys(), split_t):
hop_split_d[hid][column] = ht
for hid in hop_split_d.keys():
hop_split_d[hid]["sources_range"] = range_d[(batch_id, hid)][
"sources_range"
]
hop_split_d[hid]["destinations_range"] = range_d[(batch_id, hid)][
"destinations_range"
]

result_tensor_d[batch_id] = hop_split_d
if "map" in batch_d:
result_tensor_d[batch_id]["map"] = batch_d["map"]
result_tensor_d[batch_id]["map"] = batch_d["map"]
return result_tensor_d


def create_homogeneous_sampled_graphs_from_dataframe(
sampled_df: cudf.DataFrame,
edge_dir: str = "in",
return_type: str = "dgl.Block",
):
"""
This helper function creates DGL MFGS for
Expand All @@ -136,85 +158,128 @@ def create_homogeneous_sampled_graphs_from_dataframe(
- output_nodes: The output nodes for the batch.
- graph_per_hop_ls: A list of DGL MFGS for each hop.
"""
if return_type not in ["dgl.Block", "cugraph_dgl.nn.SparseGraph"]:
raise ValueError(
"return_type must be either dgl.Block or cugraph_dgl.nn.SparseGraph"
)

result_tensor_d = _get_tensor_d_from_sampled_df(sampled_df)
del sampled_df
result_mfgs = [
_create_homogeneous_sampled_graphs_from_tensors_perhop(
tensors_batch_d, edge_dir
tensors_batch_d, edge_dir, return_type
)
for tensors_batch_d in result_tensor_d.values()
]
del result_tensor_d
return result_mfgs


def _create_homogeneous_sampled_graphs_from_tensors_perhop(tensors_batch_d, edge_dir):
def _create_homogeneous_sampled_graphs_from_tensors_perhop(
tensors_batch_d, edge_dir, return_type
):
"""
This helper function creates sampled DGL MFGS for
homogeneous graphs from tensors per hop for a single
batch
Args:
tensors_batch_d (dict): A dictionary of tensors, keyed by hop_id.
edge_dir (str): Direction of edges from samples
metagraph (dgl.metagraph): The metagraph for the sampled graph
return_type (str): The type of graph to return
Returns:
tuple: A tuple of three elements:
- input_nodes: The input nodes for the batch.
- output_nodes: The output nodes for the batch.
- graph_per_hop_ls: A list of DGL MFGS for each hop.
- graph_per_hop_ls: A list of MFGS for each hop.
"""
if edge_dir not in ["in", "out"]:
raise ValueError(f"Invalid edge_dir {edge_dir} provided")
if edge_dir == "out":
raise ValueError("Outwards edges not supported yet")
graph_per_hop_ls = []
seednodes = None
seednodes_range = None
for hop_id, tensor_per_hop_d in tensors_batch_d.items():
if hop_id != "map":
block = _create_homogeneous_dgl_block_from_tensor_d(
tensor_per_hop_d, tensors_batch_d["map"], seednodes
if return_type == "dgl.Block":
mfg = _create_homogeneous_dgl_block_from_tensor_d(
tensor_d=tensor_per_hop_d,
renumber_map=tensors_batch_d["map"],
seednodes_range=seednodes_range,
)
elif return_type == "cugraph_dgl.nn.SparseGraph":
mfg = _create_homogeneous_cugraph_dgl_nn_sparse_graph(
tensor_d=tensor_per_hop_d, seednodes_range=seednodes_range
)
else:
raise ValueError(f"Invalid return_type {return_type} provided")
seednodes_range = max(
tensor_per_hop_d["sources_range"],
tensor_per_hop_d["destinations_range"],
)
seednodes = torch.concat(
[tensor_per_hop_d["sources"], tensor_per_hop_d["destinations"]]
)
graph_per_hop_ls.append(block)
graph_per_hop_ls.append(mfg)

# default DGL behavior
if edge_dir == "in":
graph_per_hop_ls.reverse()

input_nodes = graph_per_hop_ls[0].srcdata[dgl.NID]
output_nodes = graph_per_hop_ls[-1].dstdata[dgl.NID]
if return_type == "dgl.Block":
input_nodes = graph_per_hop_ls[0].srcdata[dgl.NID]
output_nodes = graph_per_hop_ls[-1].dstdata[dgl.NID]
else:
map = tensors_batch_d["map"]
input_nodes = map[0 : graph_per_hop_ls[0].num_src_nodes()]
output_nodes = map[0 : graph_per_hop_ls[-1].num_dst_nodes()]
return input_nodes, output_nodes, graph_per_hop_ls


def _create_homogeneous_dgl_block_from_tensor_d(tensor_d, renumber_map, seednodes=None):
def _create_homogeneous_dgl_block_from_tensor_d(
tensor_d,
renumber_map,
seednodes_range=None,
):
rs = tensor_d["sources"]
rd = tensor_d["destinations"]

max_src_nodes = rs.max()
max_dst_nodes = rd.max()
if seednodes is not None:
# If we have isolated vertices
max_src_nodes = tensor_d["sources_range"]
max_dst_nodes = tensor_d["destinations_range"]
if seednodes_range is not None:
# If we have vertices without outgoing edges, then
# sources can be missing from seednodes
# so we add them
# to ensure all the blocks are
# linedup correctly
max_dst_nodes = max(max_dst_nodes, seednodes.max())
# lined up correctly
max_dst_nodes = max(max_dst_nodes, seednodes_range)

data_dict = {("_N", "_E", "_N"): (rs, rd)}
num_src_nodes = {"_N": max_src_nodes.item() + 1}
num_dst_nodes = {"_N": max_dst_nodes.item() + 1}
num_src_nodes = {"_N": max_src_nodes + 1}
num_dst_nodes = {"_N": max_dst_nodes + 1}

block = dgl.create_block(
data_dict=data_dict, num_src_nodes=num_src_nodes, num_dst_nodes=num_dst_nodes
)
if "edge_id" in tensor_d:
block.edata[dgl.EID] = tensor_d["edge_id"]
block.srcdata[dgl.NID] = renumber_map[block.srcnodes()]
block.dstdata[dgl.NID] = renumber_map[block.dstnodes()]
# Below adds run time overhead
block.srcdata[dgl.NID] = renumber_map[0 : max_src_nodes + 1]
block.dstdata[dgl.NID] = renumber_map[0 : max_dst_nodes + 1]
return block


def _create_homogeneous_cugraph_dgl_nn_sparse_graph(tensor_d, seednodes_range):
max_src_nodes = tensor_d["sources_range"]
max_dst_nodes = tensor_d["destinations_range"]
if seednodes_range is not None:
max_dst_nodes = max(max_dst_nodes, seednodes_range)
size = (max_src_nodes + 1, max_dst_nodes + 1)
sparse_graph = cugraph_dgl.nn.SparseGraph(
size=size,
src_ids=tensor_d["sources"],
dst_ids=tensor_d["destinations"],
formats=["csc"],
reduce_memory=True,
)
return sparse_graph


def create_heterogeneous_sampled_graphs_from_dataframe(
sampled_df: cudf.DataFrame,
num_nodes_dict: Dict[str, int],
Expand Down
Loading

0 comments on commit f0d16c1

Please sign in to comment.