From 1240df9453ecf1f57133abbcff3fb2c4921bf4ca Mon Sep 17 00:00:00 2001 From: Chang Liu Date: Thu, 17 Oct 2024 14:59:32 -0700 Subject: [PATCH 1/6] Add example --- examples/distributed/wholegraph/README | 62 +++ .../distributed/wholegraph/benchmark_data.py | 229 ++++++++++ .../distributed/wholegraph/feature_store.py | 116 +++++ .../distributed/wholegraph/graph_store.py | 226 ++++++++++ .../nv_distributed_graph/__init__.py | 4 + .../nv_distributed_graph/dist_graph.py | 92 ++++ .../nv_distributed_graph/dist_shmem.py | 131 ++++++ .../nv_distributed_graph/dist_tensor.py | 424 ++++++++++++++++++ .../nv_distributed_graph/wholegraph.py | 250 +++++++++++ .../papers100m_dist_wholegraph_nc.py | 194 ++++++++ 10 files changed, 1728 insertions(+) create mode 100644 examples/distributed/wholegraph/README create mode 100644 examples/distributed/wholegraph/benchmark_data.py create mode 100644 examples/distributed/wholegraph/feature_store.py create mode 100644 examples/distributed/wholegraph/graph_store.py create mode 100644 examples/distributed/wholegraph/nv_distributed_graph/__init__.py create mode 100644 examples/distributed/wholegraph/nv_distributed_graph/dist_graph.py create mode 100644 examples/distributed/wholegraph/nv_distributed_graph/dist_shmem.py create mode 100644 examples/distributed/wholegraph/nv_distributed_graph/dist_tensor.py create mode 100644 examples/distributed/wholegraph/nv_distributed_graph/wholegraph.py create mode 100644 examples/distributed/wholegraph/papers100m_dist_wholegraph_nc.py diff --git a/examples/distributed/wholegraph/README b/examples/distributed/wholegraph/README new file mode 100644 index 000000000000..f10147c64b05 --- /dev/null +++ b/examples/distributed/wholegraph/README @@ -0,0 +1,62 @@ +# Using NVIDIA WholeGraph Library for Distributed Training with PyG + +**[RAPIDS WholeGraph](https://github.com/rapidsai/wholegraph)** +NVIDIA WholeGraph is designed to optimize the training of Graph Neural Networks (GNNs) that are often constrained by data loading operations. It provides an underlying storage structure, called WholeMemory, which efficiently manages data storage/communication across disk, RAM, and device memory by leveraging NVIDIA GPUs and communication libraries like NCCL/NVSHMEM. + +WholeGraph is a low-level graph storage library, integrated into and able to work alongside cuGraph, that directly provides an efficient feature and graph store with associated primitive operations (e.g., GPU-accelerated fast embedding retrieval and graph sampling). It is specifically optimized for NVLink systems, including DGX, MGX, and GH/GB200 machine or clusters. + +This example demonstrates how to use WholeGraph to easily distribute the graph and feature store to pinned-host memory for fast GPU UVA access (see the DistTensor class), eliminating the need for manual graph partitioning or any custom third-party launch scripts. WholeGraph seamlessly integrates with PyTorch's Distributed Data Parallel (DDP) setup and works with standard distributed job launchers such as torchrun, mpirun, or srun. + +## Requirements + +- **PyTorch**: `>= 2.0` +- **PyTorch Geometric**: `>= 2.0.0` +- **WholeGraph**: `>= 24.02` +- **NVIDIA GPU(s)** + +## Environment Setup + +```bash + pip install pylibwholegraph-cu12 +``` + +## Sinlge/Multi-GPU Run + +Using PyTorch torchrun elastic launcher: +``` +torchrun papers100m_dist_wholegraph_nc.py +``` +or, using multi-GPUs if applicable: +``` +torchrun --nnodes 1 --nproc-per-node papers100m_dist_wholegraph_nc.py +``` + +## Distributed (multi-node) Run + +For example, let's use the slurm launcher here: + +``` +srun -N --ntasks-per-node= python papers100m_dist_wholegraph_nc.py +``` + +Note the above command line setting is simplified for demonstration purposes. For more details, please refer to this [sbatch script](https://github.com/chang-l/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling_multinode.sbatch), as cluster setups may vary. + + +## Benchmark Run + +The benchmark script is similar to the above example but includes a `--mode` command-line argument, allowing users to easily compare PyG's native features/graph store (`torch_geometric.data.Data` and `torch_geometric.data.HeteroData`) with the WholeMemory-based feature store and graph store, shown in this example. It performs a node classification task on the `ogbn-products` dataset. + +### PyG baseline +``` +torchrun --nnodes 1 --nproc-per-node benchmark_data.py --mode baseline +``` + +### WholeGraph FeatureStore integration (UVA for feature store access) +``` +torchrun --nnodes 1 --nproc-per-node benchmark_data.py --mode UVA-features +``` + +### WholeGraph FeatureStore + GraphStore (UVA for feature and graph store access) +``` +torchrun --nnodes 1 --nproc-per-node benchmark_data.py --mode UVA +``` \ No newline at end of file diff --git a/examples/distributed/wholegraph/benchmark_data.py b/examples/distributed/wholegraph/benchmark_data.py new file mode 100644 index 000000000000..ee5fa5dbe94a --- /dev/null +++ b/examples/distributed/wholegraph/benchmark_data.py @@ -0,0 +1,229 @@ +"""Multi-node multi-GPU example on ogbn-papers100m. + +Example way to run using srun: +srun -l -N --ntasks-per-node= \ +--container-name=cont --container-image= \ +--container-mounts=/ogb-papers100m/:/workspace/dataset +python3 path_to_script.py +""" +import os +import time +from typing import Optional +import argparse + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from ogb.nodeproppred import PygNodePropPredDataset +from torch.nn.parallel import DistributedDataParallel +from torchmetrics import Accuracy + +from torch_geometric.loader import NodeLoader, NeighborLoader +from torch_geometric.nn import GCN + +from torch_geometric.sampler import BaseSampler + +from nv_distributed_graph import dist_shmem +from feature_store import WholeGraphFeatureStore +from graph_store import WholeGraphGraphStore + +class WholeGraphSampler(BaseSampler): + r""" + A naive sampler class for WholeGraph graph storage that only supports uniform node-based sampling on homogeneous graph. + """ + from torch_geometric.sampler import SamplerOutput, NodeSamplerInput + def __init__( + self, + graph: WholeGraphGraphStore, + num_neighbors, + ): + import pylibwholegraph.torch as wgth + + self.num_neighbors = num_neighbors + self.wg_sampler = wgth.GraphStructure() + row_indx, col_ptrs, _ = graph.csc() + self.wg_sampler.set_csr_graph(col_ptrs._tensor, row_indx._tensor) + + def sample_from_nodes( + self, + inputs: NodeSamplerInput + ) -> SamplerOutput: + r""" + Sample subgraphs from the given nodes based on uniform node-based sampling. + """ + seed = inputs.node.cuda(non_blocking=True) # WholeGraph Sampler needs all seeds on device + WG_SampleOutput = self.wg_sampler.multilayer_sample_without_replacement(seed, self.num_neighbors, None) + out = WholeGraphGraphStore.create_pyg_subgraph(WG_SampleOutput) + out.metadata = (inputs.input_id, inputs.time) + return out + +def run(world_size, rank, local_rank, device, mode): + wall_clock_start = time.perf_counter() + + # Will query the runtime environment for `MASTER_ADDR` and `MASTER_PORT`. + # Make sure, those are set! + dist.init_process_group('nccl', world_size=world_size, rank=rank) + dist_shmem.init_process_group_per_node() + + # Load the dataset in the local root process and share it with local ranks + if dist_shmem.get_local_rank() == 0: + dataset = PygNodePropPredDataset(name='ogbn-products', root='/workspace') + else: + dataset = None + dataset = dist_shmem.to_shmem(dataset) # move dataset to shmem + + split_idx = dataset.get_idx_split() + split_idx['train'] = split_idx['train'].split( + split_idx['train'].size(0) // world_size, dim=0)[rank].clone() + split_idx['valid'] = split_idx['valid'].split( + split_idx['valid'].size(0) // world_size, dim=0)[rank].clone() + split_idx['test'] = split_idx['test'].split( + split_idx['test'].size(0) // world_size, dim=0)[rank].clone() + data = dataset[0] + num_features = dataset.num_features + num_classes = dataset.num_classes + + if mode == 'baseline': + data = data + kwargs = dict( + data=data, + batch_size=1024, + num_neighbors=[30, 30], + num_workers=4, + ) + train_loader = NeighborLoader( + input_nodes=split_idx['train'], + shuffle=True, + drop_last=True, + **kwargs, + ) + val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs) + test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs) + + elif mode == 'UVA-features': + feature_store = WholeGraphFeatureStore(pyg_data=data) + graph_store = WholeGraphGraphStore(pyg_data=data, format='pyg') + data = (feature_store, graph_store) + kwargs = dict( + data=data, + batch_size=1024, + num_neighbors=[30, 30], + num_workers=4, + filter_per_worker=False, # WholeGraph feature fetching is not fork-safe + ) + train_loader = NeighborLoader( + input_nodes=split_idx['train'], + shuffle=True, + drop_last=True, + **kwargs, + ) + val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs) + test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs) + + elif mode == 'UVA': + feature_store = WholeGraphFeatureStore(pyg_data=data) + graph_store = WholeGraphGraphStore(pyg_data=data) + data = (feature_store, graph_store) + kwargs = dict( + data=data, + batch_size=1024, + num_workers=0, # with wholegraph sampler you don't need workers + filter_per_worker=False, # WholeGraph feature fetching is not fork-safe + ) + node_sampler = WholeGraphSampler( + graph_store, + num_neighbors=[30, 30], + ) + train_loader = NodeLoader( + input_nodes=split_idx['train'], + node_sampler=node_sampler, + shuffle=True, + drop_last=True, + **kwargs, + ) + val_loader = NodeLoader(input_nodes=split_idx['valid'], node_sampler=node_sampler, **kwargs) + test_loader = NodeLoader(input_nodes=split_idx['test'], node_sampler=node_sampler, **kwargs) + + eval_steps = 1000 + model = GCN(num_features, 256, 2, num_classes) + acc = Accuracy(task="multiclass", num_classes=num_classes).to(device) + model = DistributedDataParallel(model.to(device), device_ids=[local_rank]) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001, + weight_decay=5e-4) + + if rank == 0: + prep_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total time before training begins (prep_time)=", prep_time, + "seconds") + print("Beginning training...") + + for epoch in range(1, 21): + dist.barrier() + start = time.time() + model.train() + for i, batch in enumerate(train_loader): + batch = batch.to(device) + optimizer.zero_grad() + y = batch.y[:batch.batch_size].view(-1).to(torch.long) + out = model(batch.x, batch.edge_index)[:batch.batch_size] + loss = F.cross_entropy(out, y) + loss.backward() + optimizer.step() + if rank == 0 and i % 100 == 0: + print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}') + + # Profile run: + # We synchronize before barrier to flush GPU OPs first, + # then adding barrier to sync CPUs to find max train time among all ranks. + torch.cuda.synchronize() + dist.barrier() + epoch_end = time.time() + + @torch.no_grad() + def test(loader: NodeLoader, num_steps: Optional[int] = None): + model.eval() + for j, batch in enumerate(loader): + if num_steps is not None and j >= num_steps: + break + batch = batch.to(device) + out = model(batch.x, batch.edge_index)[:batch.batch_size] + y = batch.y[:batch.batch_size].view(-1).to(torch.long) + acc(out, y) + acc_sum = acc.compute() + return acc_sum + + eval_acc = test(val_loader, num_steps=eval_steps) + if rank == 0: + print(f"Val Accuracy: {eval_acc:.4f}%", ) + print( + f"Epoch {epoch:05d} | " + f"Accuracy {eval_acc:.4f} | " + f"Time {epoch_end - start:.2f}" + ) + + acc.reset() + dist.barrier() + + test_acc = test(test_loader) + if rank == 0: + print(f"Test Accuracy: {test_acc:.4f}%", ) + dist.destroy_process_group() if dist.is_initialized() else None + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--mode', type=str, default='baseline', choices=['baseline', 'UVA-features', 'UVA']) + args = parser.parse_args() + + # Get the world size from the WORLD_SIZE variable or directly from SLURM: + world_size = int( + os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS'))) + # Likewise for RANK and LOCAL_RANK: + rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID'))) + local_rank = int( + os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID'))) + + assert torch.cuda.is_available() + device = torch.device(local_rank) + torch.cuda.set_device(device) + run(world_size, rank, local_rank, device, args.mode) + diff --git a/examples/distributed/wholegraph/feature_store.py b/examples/distributed/wholegraph/feature_store.py new file mode 100644 index 000000000000..e9778d21d041 --- /dev/null +++ b/examples/distributed/wholegraph/feature_store.py @@ -0,0 +1,116 @@ +from typing import Optional, Union + +import torch +import torch.distributed as dist + +import torch_geometric +from nv_distributed_graph import DistTensor, DistEmbedding, dist_shmem, nvlink_network + +from torch_geometric.data.feature_store import FeatureStore, TensorAttr + +class WholeGraphFeatureStore(FeatureStore): + r""" A high-performance, UVA-enabled, and multi-GPU/multi-node friendly feature store, powered by WholeGraph library. + It is compatible with PyG's FeatureStore class and supports both homogeneous and heterogeneous graph data types. + + Args: + pyg_data (torch_geometric.data.Data or torch_geometric.data.HeteroData): The input PyG graph data. + + Attributes: + _store (dict): A dictionary to hold the feature embeddings. + backend (str): Using 'nccl' or 'vmm' backend for inter-GPU communication if applicable. + + Methods: + _put_tensor(tensor, attr): + Puts a tensor into the feature store. + _get_tensor(attr): + Retrieves a tensor from the feature store with a given set of indexes. + _remove_tensor(attr): + Not yet implemented; intended for compatibility with PyG's FeatureStore class. + _get_tensor_size(attr): + Returns the size of a tensor in the feature store. + get_all_tensor_attrs(): + Obtains all feature attributes stored in the feature store. + """ + def __init__(self, pyg_data): + r"""Initializes the WholeGraphFeatureStore class and loads features from torch_geometric.data.Data/HeteroData.""" + super().__init__() + self._store = {} # A dictionary of tuple to hold the feature embeddings + + if dist_shmem.get_local_rank() == dist.get_rank(): + self.backend = 'vmm' + else: + self.backend = 'vmm' if nvlink_network() else 'nccl' + + if isinstance(pyg_data, torch_geometric.data.Data): + self.put_tensor(pyg_data['x'], group_name=None, attr_name='x', index=None) + self.put_tensor(pyg_data['y'], group_name=None, attr_name='y', index=None) + + elif isinstance(pyg_data, torch_geometric.data.HeteroData): # if HeteroData, we need to handle differently + for group_name, group in pyg_data.node_items(): + for attr_name in group: + if group.is_node_attr(attr_name) and attr_name in {'x', 'y'}: + self.put_tensor(pyg_data[group_name][attr_name], group_name=group_name, attr_name=attr_name, index=None) + # This is a hack for MAG240M dataset, to add node features for 'institution' and 'author' nodes. + # This should not be presented in the upstream code. + elif attr_name == 'num_nodes': + feature_dim = 768 + num_nodes = group[attr_name] + shape=[num_nodes, feature_dim] + self[group_name, 'x', None] = DistEmbedding(shape=shape, dtype=torch.float16, device="cpu", backend=self.backend) + else: + raise TypeError("Expected pyg_data to be of type torch_geometric.data.Data or torch_geometric.data.HeteroData.") + + def _put_tensor(self, tensor: torch.Tensor, attr): + """ + Creates and stores features (either DistTensor or DistEmbedding) from the given tensor, + using a key derived from the group and attribute name. + + Args: + tensor (torch.Tensor): The tensor to be passed to the feature store. + attr: PyG's TensorAttr to fully specify each feature store. + """ + key = (attr.group_name, attr.attr_name) + out = self._store.get(key) + if out is not None and attr.index is not None: + out[attr.index] = tensor + else: + assert attr.index is None + if tensor.dim() == 1: + # No need to unsqueeze if WholeGraph fix this https://github.com/rapidsai/wholegraph/pull/229 + self._store[key] = DistTensor(tensor.unsqueeze(1), device="cpu", backend=self.backend) + else: + self._store[key] = DistEmbedding(tensor, device="cpu", backend=self.backend) + return True + + def _get_tensor(self, attr) -> Optional[Union[torch.Tensor, DistTensor, DistEmbedding]]: + """ + Retrieves a tensor based on the provided attribute. + + Args: + attr: An object containing the necessary attributes to fetch the tensor. + + Returns: + A tensor which can be of type torch.Tensor, DistTensor, or DistEmbedding, or None if not found. + """ + + key = (attr.group_name, attr.attr_name) + tensor = self._store.get(key) + if tensor is not None: + if attr.index is not None: + output = tensor[attr.index] + return output + else: + return tensor + return None + + def _remove_tensor(self, attr): + pass + + def _get_tensor_size(self, attr): + return self._get_tensor(attr).shape + + def get_all_tensor_attrs(self): + r"""Obtains all feature attributes stored in `Data`.""" + return [ + TensorAttr(group_name=group, attr_name=name) for group, name in self._store.keys() + ] \ No newline at end of file diff --git a/examples/distributed/wholegraph/graph_store.py b/examples/distributed/wholegraph/graph_store.py new file mode 100644 index 000000000000..ac11c7d304c9 --- /dev/null +++ b/examples/distributed/wholegraph/graph_store.py @@ -0,0 +1,226 @@ +from typing import List, Optional, Tuple +from dataclasses import dataclass + +import torch.distributed as dist + +import torch_geometric +from torch_geometric.data.graph_store import GraphStore, EdgeAttr, EdgeLayout +from torch_geometric.sampler import SamplerOutput + +from torch_geometric.typing import EdgeType + +from nv_distributed_graph import DistGraphCSC, dist_shmem, nvlink_network + +@dataclass +class WholeGraphEdgeAttr(EdgeAttr): + r"""Edge attribute class for WholeGraph GraphStore enforcing layout to be CSC.""" + def __init__( + self, + edge_type: Optional[EdgeType] = None, # use string to represent edge type for simplicity + is_sorted: bool = False, + size: Optional[Tuple[int, int]] = None, + ): + layout = EdgeLayout.CSC # Enforce CSC layout for WholeGraph for now + super().__init__(edge_type, layout, is_sorted, size) + +class WholeGraphGraphStore(GraphStore): + r""" A high-performance, UVA-enabled, and multi-GPU/multi-node friendly graph store, powered by WholeGraph library. + It is compatible with PyG's GraphStore base class and supports both homogeneous and heterogeneous graph data types. + + Args: + pyg_data (torch_geometric.data.Data or torch_geometric.data.HeteroData): The input PyG graph data. + format (str): The underlying graph format to use. Default is 'wholegraph'. + + + """ + def __init__(self, pyg_data, format='wholegraph'): + super().__init__(edge_attr_cls=WholeGraphEdgeAttr) + self._g = {} # for simplicy, _g is a dictionary of DistGraphCSC to hold the graph structure data for each type + + if format == 'wholegraph': + pinned_shared = False + if dist_shmem.get_local_rank() == dist.get_rank(): + backend = 'vmm' + else: + backend = 'vmm' if nvlink_network() else 'nccl' + elif format == 'pyg': + pinned_shared = True + backend = None # backend is a no-op for pyg format + else: + raise ValueError("Unsupported underlying graph format") + + if isinstance(pyg_data, torch_geometric.data.Data): + # issue: this will crash: pyg_data.get_all_edge_attrs()[0] if pyg_data is a torch sparse csr + # walkaround: + if 'adj_t' not in pyg_data: + row, col = None, None + if dist_shmem.get_local_rank() == 0: + row, col, _ = pyg_data.csc() # discard permutation for now + row = dist_shmem.to_shmem(row) + col = dist_shmem.to_shmem(col) + size = pyg_data.size() + else: + # issue: it wont work if adj_t is a SparseTensor + col = pyg_data.adj_t.crow_indices() + row = pyg_data.adj_t.col_indices() + size = pyg_data.adj_t.size()[::-1] + + self.num_nodes = pyg_data.num_nodes + graph = DistGraphCSC( + col, + row, + device="cpu", + backend=backend, + pinned_shared=pinned_shared, + ) + self.put_adj_t(graph, size=size) + + elif isinstance(pyg_data, torch_geometric.data.HeteroData): # hetero graph + # issue: this will crash: pyg_data.get_all_edge_attrs()[0] if pyg_data is a torch sparse csr + # walkaround: + self.num_nodes = pyg_data.num_nodes + for edge_type, edge_store in pyg_data.edge_items(): + if 'adj_t' not in edge_store: + row, col = None, None + if dist_shmem.get_local_rank() == 0: + row, col, _ = edge_store.csc() # discard permutation for now + row = dist_shmem.to_shmem(row) + col = dist_shmem.to_shmem(col) + size = edge_store.size() + else: + # issue: this will also if adj_t is a SparseTensor + col = edge_store.adj_t.crow_indices() + row = edge_store.adj_t.col_indices() + size = edge_store.adj_t.size()[::-1] + graph = DistGraphCSC( + col, + row, + device="cpu", + backend=backend, + pinned_shared=pinned_shared, + ) + self.put_adj_t(graph, edge_type=edge_type, size=size) + + def put_adj_t(self, adj_t: DistGraphCSC, *args, **kwargs) -> bool: + r"""Synchronously adds an :obj:`edge_index` tuple to the + :class:`GraphStore`. + Returns whether insertion was successful. + + Args: + edge_index (Tuple[torch.Tensor, torch.Tensor]): The + :obj:`edge_index` tuple in a format specified in + :class:`EdgeAttr`. + *args: Arguments passed to :class:`EdgeAttr`. + **kwargs: Keyword arguments passed to :class:`EdgeAttr`. + """ + edge_attr = self._edge_attr_cls.cast(*args, **kwargs) + return self._put_adj_t(adj_t, edge_attr) + + def get_adj_t(self, *args, **kwargs) -> DistGraphCSC: + edge_attr = self._edge_attr_cls.cast(*args, **kwargs) + graph_adj_t = self._get_adj_t(edge_attr) + if graph_adj_t is None: + raise KeyError(f"'adj_t' for '{edge_attr}' not found") + return graph_adj_t + + def _put_adj_t(self, adj_t: DistGraphCSC, edge_attr: WholeGraphEdgeAttr) -> bool: + if not hasattr(self, '_edge_attrs'): + self._edge_attrs = {} + self._edge_attrs[edge_attr.edge_type] = edge_attr + + self._g[edge_attr.edge_type] = adj_t + if edge_attr.size is None: + # Hopefully, the size is already set beforehand by the input, edge_attr + # Todo: DistGraphCSC does not support size attribute, need to implement it + edge_attr.size = adj_t.size + return True + + def _get_adj_t(self, edge_attr: WholeGraphEdgeAttr) -> Optional[DistGraphCSC]: + store = self._g.get(edge_attr.edge_type) + edge_attrs = getattr(self, '_edge_attrs', {}) + edge_attr = edge_attrs[edge_attr.edge_type] + if edge_attr.size is None: + # Hopefully, the size is already set beforehand by the input, edge_attr + # Todo: DistGraphCSC does not support size attribute, need to implement it + edge_attr.size = store.size # Modify in-place. + return store + + def _get_edge_index(): + pass + + def _put_edge_index(): + pass + + def _remove_edge_index(): + pass + + def __getitem__(self, key: WholeGraphEdgeAttr): + return self.get_adj_t(key) + + def __setitem__(self, key: WholeGraphEdgeAttr, value: DistGraphCSC): + self.put_adj_t(value, key) + + def get_all_edge_attrs(self) -> List[WholeGraphEdgeAttr]: + edge_attrs = getattr(self, '_edge_attrs', {}) + for key, store in self._g.items(): + if key not in edge_attrs: + edge_attrs[key] = WholeGraphEdgeAttr( + key, size=store.size) + return list(edge_attrs.values()) + + def csc(self): + # Define this method to be compatible with pyg native neighbor sampler (if used) see: sampler/neighbour_sampler.py:L222 and L263 + if not self.is_hetero: + key = self.get_all_edge_attrs()[0] + store = self._get_adj_t(key) + return store.row_indx, store.col_ptrs, None # no permutation vector + else: + row_dict = {} + col_dict = {} + for edge_attr in self.get_all_edge_attrs(): + store = self._get_adj_t(edge_attr) + row_dict[edge_attr.edge_type] = store.row_indx + col_dict[edge_attr.edge_type] = store.col_ptrs + return row_dict, col_dict, None + + @property + def is_hetero(self): + if len(self._g) > 1: + return True + return False + + @staticmethod + def create_pyg_subgraph(WG_SampleOutput) -> Tuple: + # PyG_SampleOutput (node, row, col, edge, batch...): + # node (torch.Tensor): The sampled nodes in the original graph. + # row (torch.Tensor): The source node indices of the sampled subgraph. + # Indices must be within {0, ..., num_nodes - 1} where num_nodes is the number of nodes in sampled graph. + # col (torch.Tensor): The destination node indices of the sampled subgraph. Indices must be within {0, ..., num_nodes - 1} + # edge (torch.Tensor, optional): The sampled edges in the original graph. (for obtain edge features from the original graph) + # batch (torch.Tensor, optional): The vector to identify the seed node for each sampled node in case of disjoint subgraph + # sampling per seed node. (None) + # num_sampled_nodes (List[int], optional): The number of sampled nodes per hop. + # num_sampled_edges (List[int], optional): The number of sampled edges per hop. + sampled_nodes_list, edge_indice_list, csr_row_ptr_list, csr_col_ind_list = WG_SampleOutput + num_sampled_nodes = [] + node = sampled_nodes_list[0] + + for hop in range(len(sampled_nodes_list)-1): + sampled_nodes = len(sampled_nodes_list[hop]) - len(sampled_nodes_list[hop+1]) + num_sampled_nodes.append(sampled_nodes) + num_sampled_nodes.append(len(sampled_nodes_list[-1])) + num_sampled_nodes.reverse() + + layers = len(edge_indice_list) + num_sampled_edges = [len(csr_col_ind_list[-1])] + # Loop in reverse order, starting from the second last layer + for layer in range(layers - 2, -1, -1): + num_sampled_edges.append(len(csr_col_ind_list[layer] - len(csr_col_ind_list[layer + 1]))) + + row = csr_col_ind_list[0] # rows + col = edge_indice_list[0][1] # dst node + + edge = None + batch = None + out = node, row, col, edge, batch, num_sampled_nodes, num_sampled_edges + return SamplerOutput.cast(out) diff --git a/examples/distributed/wholegraph/nv_distributed_graph/__init__.py b/examples/distributed/wholegraph/nv_distributed_graph/__init__.py new file mode 100644 index 000000000000..6bf728997b0e --- /dev/null +++ b/examples/distributed/wholegraph/nv_distributed_graph/__init__.py @@ -0,0 +1,4 @@ +from .dist_graph import DistGraphCSC +from .dist_tensor import DistTensor, DistEmbedding +from .dist_shmem import init_process_group_per_node, get_local_process_group, get_local_root, get_local_rank, get_local_size, to_shmem +from .wholegraph import nvlink_network \ No newline at end of file diff --git a/examples/distributed/wholegraph/nv_distributed_graph/dist_graph.py b/examples/distributed/wholegraph/nv_distributed_graph/dist_graph.py new file mode 100644 index 000000000000..65f4e0314362 --- /dev/null +++ b/examples/distributed/wholegraph/nv_distributed_graph/dist_graph.py @@ -0,0 +1,92 @@ +from typing import Any, List, Union, Literal, Optional +import numpy as np + +import torch + +from . import dist_shmem +from . import dist_tensor + +class DistGraphCSC: + """ Distributed Graph Store based on DistTensors for Compressed Sparse Column (CSC) format. + Only support homogeneous graph for now. + Parameters + ---------- + node_tensor : torch.Tensor + The node tensor. + edge_tensor : torch.Tensor + """ + def __init__( + self, + col_ptrs_src: Optional[Union[torch.Tensor, str, List[str]]] = None, + row_indx_src : Optional[Union[torch.Tensor, str, List[str]]] = None, + device: Optional[Literal["cpu", "cuda"]] = "cpu", + pinned_shared: Optional[bool] = False, + partition_book: Optional[Union[List[int], None]] = None, # location memtype ?? backend?? ; engine; comm = vmm/nccl .. + backend: Optional[str] = "nccl", # reserved this for future use + *args, + **kwargs, + ): + # optionally to save node/edge feature tensors (view) + self.data = {} # place holder for the hetergenous graph + self.device = device + if partition_book is not None: + raise NotImplementedError("Uneven partition of 1-D disttensor is not turned on yet.") + + if pinned_shared: + dist_shmem.init_process_group_per_node() + # load the original dataset in the first process and share it with others + col_ptrs = None + row_indx = None + if dist_shmem.get_local_rank() == 0: + if isinstance(col_ptrs_src, torch.Tensor) and isinstance(row_indx_src, torch.Tensor): + col_ptrs = col_ptrs_src + row_indx = row_indx_src + elif col_ptrs_src.endswith('.pt') and row_indx_src.endswith('.pt'): + col_ptrs = torch.load(col_ptrs_src, mmap=True) + row_indx = torch.load(row_indx_src, mmap=True) + elif col_ptrs_src.endswith('.npy') and row_indx_src.endswith('.npy'): + col_ptrs = torch.from_numpy(np.load(col_ptrs_src, mmap_mode='c')) + row_indx = torch.from_numpy(np.load(row_indx_src, mmap_mode='c')) + else: + raise ValueError("Unsupported file format.") + + self.col_ptrs = dist_shmem.to_shmem(col_ptrs) + self.row_indx = dist_shmem.to_shmem(row_indx) + else: + # 2-gather approach here only + self.col_ptrs = dist_tensor.DistTensor(col_ptrs_src, device = device, backend = backend) + self.row_indx = dist_tensor.DistTensor(row_indx_src, device = device, backend = backend) + + @property + def num_nodes(self): + return self.col_ptrs.shape[0] - 1 + + @property + def num_edges(self): + return self.row_indx.shape[0] + + def __getitem__ (self, name: str) -> Any: + return self.data[name] + + def __setitem__(self, name: str, value: Any) -> None: + self.data[name] = value + return + + def transform_nodes(self, nodes): + """Transform all seed nodes from every rank to the local seed nodes + + Args: + nodes (_type_): _description_ + """ + pass + + def transform_edges(self, edges): + """Transform all seed edges from every rank to the local seed edges + + Args: + edges (_type_): _description_ + """ + pass + + def transform_graph() : #back to graph + pass diff --git a/examples/distributed/wholegraph/nv_distributed_graph/dist_shmem.py b/examples/distributed/wholegraph/nv_distributed_graph/dist_shmem.py new file mode 100644 index 000000000000..49dc6ac56a64 --- /dev/null +++ b/examples/distributed/wholegraph/nv_distributed_graph/dist_shmem.py @@ -0,0 +1,131 @@ +"""Utilities for launching distributed GNN tasks. """ +import os +import torch.distributed as dist +import torch.multiprocessing as mp + +_LOCAL_PROCESS_GROUP = None +_LOCAL_ROOT_GLOBAL_RANK = None +_LOCAL_ROOT_AUTH_KEY = None +nprocs_per_node = 1 + +def init_process_group_per_node(): + """Initialize the distributed process group for each node.""" + if _LOCAL_PROCESS_GROUP is None: + create_process_group_per_node() + return + else: + assert dist.get_process_group_ranks(_LOCAL_PROCESS_GROUP)[0] == _LOCAL_ROOT_GLOBAL_RANK + return + + +def create_process_group_per_node(): + """Create local process groups for each distinct node.""" + global _LOCAL_PROCESS_GROUP, _LOCAL_ROOT_GLOBAL_RANK + global nprocs_per_node + assert _LOCAL_PROCESS_GROUP is None and _LOCAL_ROOT_GLOBAL_RANK is None + assert dist.is_initialized(), "torch.distributed is not initialized. Please call torch.distributed.init_process_group() first." + + nprocs_per_node = int(os.environ.get('LOCAL_WORLD_SIZE', os.environ.get('SLURM_NTASKS_PER_NODE'))) + world_size = dist.get_world_size() + rank = dist.get_rank() if dist.is_initialized() else 0 + assert world_size % nprocs_per_node == 0 + + num_nodes = world_size // nprocs_per_node + node_id = rank // nprocs_per_node + for i in range(num_nodes): + node_ranks = list(range(i * nprocs_per_node, (i + 1) * nprocs_per_node)) + pg = dist.new_group(node_ranks) + if i == node_id: + _LOCAL_PROCESS_GROUP = pg + assert _LOCAL_PROCESS_GROUP is not None + _LOCAL_ROOT_GLOBAL_RANK = dist.get_process_group_ranks(_LOCAL_PROCESS_GROUP)[0] + + +def get_local_process_group(): + """Return a torch distributed process group for a subset of all local processes in the same node.""" + assert _LOCAL_PROCESS_GROUP is not None + return _LOCAL_PROCESS_GROUP + + +def get_local_root(): + """Return the global rank corresponding to the local root process.""" + assert _LOCAL_ROOT_GLOBAL_RANK is not None + return _LOCAL_ROOT_GLOBAL_RANK + + +def get_local_rank(): + """Return the local rank of the current process.""" + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(_LOCAL_PROCESS_GROUP) + + +def get_local_size(): + """Return the number of processes in the local process group.""" + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_world_size(_LOCAL_PROCESS_GROUP) + + +def _sync_auth_key(local_group, local_root): + """Synchronize the authentication key across local process group or node. + + NOTE: In the context of MPI or torchrun, where all processes are launched concurrently and independently, + synchronized authentication key allows local processes seamlessly exchange data during pickling process. + """ + global _LOCAL_ROOT_AUTH_KEY + assert _LOCAL_ROOT_AUTH_KEY is None + if dist.get_rank() == local_root: + authkey = [bytes(mp.current_process().authkey)] + _LOCAL_ROOT_AUTH_KEY = authkey[0] + else: + authkey = [None] + + dist.broadcast_object_list(authkey, src=local_root, group=local_group) + if authkey[0] != bytes(mp.current_process().authkey): + mp.current_process().authkey = authkey[0] + _LOCAL_ROOT_AUTH_KEY = authkey[0] + assert _LOCAL_ROOT_AUTH_KEY == bytes(mp.current_process().authkey) + + +def _is_authkey_sync(): + global _LOCAL_ROOT_AUTH_KEY + return _LOCAL_ROOT_AUTH_KEY == bytes(mp.current_process().authkey) + + +def to_shmem(dataset): + """Move the dataset into shared memory. + + NOTE: This function performs dataset dumping/loading via a customizble pickler from the multiprocessing module. + Frameworks (e.g., DGL and PyTorch) have the capability to customize the pickling process for their specific + objects (e.g., DGLGraph or PyTorch Tensor), which involves moving the objects to shared memory at the local + root (ForkingPickler.dumps), and then making them accessible to all local processes (ForkingPickler.loads). + Parameters + ---------- + dataset : Tuple or List of supported objects + The objects can be DGLGraph and Pytorch Tensor, or any customized objects with the same mechanism + of using shared memory during pickling process. + + Returns + ------- + dataset : Reconstructed dataset in shared memory + Returned dataset preserves the same object hierarchy of the input. + + """ + local_root = get_local_root() + local_group = get_local_process_group() + if not _is_authkey_sync(): # if authkey not synced, sync the key + _sync_auth_key(local_group, local_root) + if dist.get_rank() == local_root: + # each non-root process should have a dedicated pickle.dumps() + handles = [None] + [ + bytes(mp.reductions.ForkingPickler.dumps(dataset)) + for _ in range(dist.get_world_size(group=local_group) - 1) + ] + else: + handles = [None] * dist.get_world_size(group=local_group) + dist.broadcast_object_list(handles, src=local_root, group=local_group) + handle = handles[dist.get_rank(group=local_group)] + if dist.get_rank() != local_root: + # only non-root process performs pickle.loads() + dataset = mp.reductions.ForkingPickler.loads(handle) + dist.barrier(group=local_group) # necessary to prevent unexpected close of any procs beyond this function + return dataset diff --git a/examples/distributed/wholegraph/nv_distributed_graph/dist_tensor.py b/examples/distributed/wholegraph/nv_distributed_graph/dist_tensor.py new file mode 100644 index 000000000000..dc67efb2f8c3 --- /dev/null +++ b/examples/distributed/wholegraph/nv_distributed_graph/dist_tensor.py @@ -0,0 +1,424 @@ +import atexit +from typing import List, Literal, Optional, Union + +import numpy as np +import torch +import torch.distributed as dist + +import pylibwholegraph +from .wholegraph import create_wg_dist_tensor, create_wg_dist_tensor_from_files, finalize_wholegraph, _wm_global +from .wholegraph import copy_host_global_tensor_to_local + +class DistTensor: + _instance_count = 0 + """ + WholeGraph-backed Distributed Tensor Interface for PyTorch. + Parameters + ---------- + src: Optional[Union[torch.Tensor, str, List[str]]] + The source of the tensor. It can be a torch.Tensor on host, a file path, or a list of file paths. + When the source is omitted, the tensor will be load later. + shape : Optional[list, tuple] + The shape of the tensor. It has to be a one- or two-dimensional tensor for now. + When the shape is omitted, the `src` has to be specified and must be `pt` or `npy` file paths. + dtype : Optional[torch.dtype] + The dtype of the tensor. The data type has to be the one in the deep learning framework. + Whne the dtype is omitted, the `src` has to be specified and must be `pt` or `npy` file paths. + device : Optional[Literal["cpu", "cuda"]] = "cpu" + The desired location to store the embedding [ "cpu" | "cuda" ]. Default is "cpu", i.e., pinned-host memory. + partition_book : Union[List[int], None] = None + 1-D Range partition based on entry (dim-0). partition_book[i] determines the + entry count of rank i and shoud be a positive integer; the sum of partition_book should equal to shape[0]. + Entries will be equally partitioned if None. + backend : Optional[Literal["vmm", "nccl", "nvshmem", "chunked"]] = "nccl" + The backend used for communication. Default is "nccl". + """ + def __init__( + self, + src: Optional[Union[torch.Tensor, str, List[str]]] = None, + shape: Optional[Union[list, tuple]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Literal["cpu", "cuda"]] = "cpu", + partition_book: Optional[Union[List[int], None]] = None, # location memtype ?? backend?? ; engine; comm = vmm/nccl .. + backend: Optional[str] = "nccl", + *args, + **kwargs, + ): + if DistTensor._instance_count == 0 and _wm_global is False: + # Register the cleanup function for safty exit + atexit.register(finalize_wholegraph) + + self._tensor = None # WholeMemory tensor for now. In future, we may support other types of distributed tensors. + self._device = device + if src is None: + # Create an empty WholeGraph tensor + assert shape is not None, "Please specify the shape of the tensor." + assert dtype is not None, "Please specify the dtype of the tensor." + assert len(shape) == 1 or len(shape) == 2, "The shape of the tensor must be 1D or 2D." + self._tensor = create_wg_dist_tensor(list(shape), dtype, device, partition_book, backend, *args, **kwargs) + self._dtype = dtype + else: + if isinstance(src, list): + # A list of file paths for a tensor + # Only support the binary file format directly loaded via WM API for now + # TODO (@liuc): support merging multiple pt or npy files to create a tensor + assert shape is not None and dtype is not None, "For now, read from multiple files are only supported in binary format." + self._tensor = create_wg_dist_tensor_from_files(src, shape, dtype, device, partition_book, backend, *args, **kwargs) + #self._tensor.from_filelist(src) + self._dtype = dtype + else: + if isinstance(src, torch.Tensor): + self._tensor = create_wg_dist_tensor(list(src.shape), src.dtype, device, partition_book, backend, *args, **kwargs) + self._dtype = src.dtype + host_tensor = src + elif isinstance(src, str) and src.endswith('.pt'): + host_tensor = torch.load(src, mmap=True) + self._tensor = create_wg_dist_tensor(list(host_tensor.shape), host_tensor.dtype, device, partition_book, backend, *args, **kwargs) + self._dtype = host_tensor.dtype + elif isinstance(src, str) and src.endswith('.npy'): + host_tensor = torch.from_numpy(np.load(src, mmap_mode='c')) + self._dtype = host_tensor.dtype + self._tensor = create_wg_dist_tensor(list(host_tensor.shape), host_tensor.dtype, device, partition_book, backend, *args, **kwargs) + else: + raise ValueError("Unsupported source type. Please provide a torch.Tensor, a file path, or a list of file paths.") + + self.load_from_global_tensor(host_tensor) + DistTensor._instance_count += 1 # increase the instance count to track for resource cleanup + + def load_from_global_tensor(self, tensor): + # input pytorch host tensor (mmapped or in shared host memory), and copy to wholegraph tensor + assert self._tensor is not None, "Please create WholeGraph tensor first." + self._dtype = tensor.dtype + if isinstance(self._tensor, pylibwholegraph.torch.WholeMemoryEmbedding): + _tensor = self._tensor.get_embedding_tensor() + else: + _tensor = self._tensor + copy_host_global_tensor_to_local(_tensor, tensor, _tensor.get_comm()) + + def load_from_local_tensor(self, tensor): + # input pytorch host tensor (mmapped or in shared host memory), and copy to wholegraph tensor + assert self._tensor is not None, "Please create WholeGraph tensor first." + assert self._tensor.local_shape == tensor.shape, "The shape of the tensor does not match the shape of the local tensor." + assert self._dtype == tensor.dtype, "The dtype of the tensor does not match the dtype of the local tensor." + if isinstance(self._tensor, pylibwholegraph.torch.WholeMemoryEmbedding): + self._tensor.get_embedding_tensor().get_local_tensor().copy_(tensor) + else: + self._tensor.get_local_tensor().copy_(tensor) + + + @classmethod + def from_tensor(cls, tensor: torch.Tensor, device: Optional[Literal["cpu", "cuda"]] = "cpu", partition_book: Union[List[int], None] = None, backend: Optional[str] = 'nccl'): + """ + Create a WholeGraph-backed Distributed Tensor from a PyTorch tensor. + + Parameters + ---------- + tensor : torch.Tensor + The PyTorch tensor to be copied to the WholeGraph tensor. + device : str, optional + The desired location to store the embedding [ "cpu" | "cuda" ]. Default is "cpu". + backend : str, optional + The backend used for communication. Default is "nccl". + + Returns + ------- + DistTensor + The WholeGraph-backed Distributed Tensor. + """ + return cls(src=tensor, device=device, partition_book=partition_book, backend=backend) + + @classmethod + def from_file(cls, file_path: str, device: Optional[Literal["cpu", "cuda"]] = "cpu", partition_book: Union[List[int], None] = None, backend: Optional[str] = 'nccl'): + """ + Create a WholeGraph-backed Distributed Tensor from a file. + + Parameters + ---------- + file_path : str + The file path to the tensor. The file can be in the format of PyTorch tensor or NumPy array. + device : str, optional + The desired location to store the embedding [ "cpu" | "cuda" ]. Default is "cpu". + backend : str, optional + The backend used for communication. Default is "nccl". + + Returns + ------- + DistTensor + The WholeGraph-backed Distributed Tensor. + """ + return cls(src=file_path, device=device, partition_book=partition_book, backend=backend) + + + def __setitem__(self, idx: torch.Tensor, val: torch.Tensor): + """ + Set the embeddings for the specified node indices. + This call must be called by all processes. + + Parameters + ---------- + idx : torch.Tensor + Index of the embeddings to collect. + val : torch.Tensor + The requested node embeddings. + """ + assert self._tensor is not None, "Please create WholeGraph tensor first." + idx = idx.cuda() + val = val.cuda() + + if val.dtype != self.dtype: + val = val.to(self.dtype) + self._tensor.scatter(val, idx) + + def __getitem__(self, idx: torch.Tensor) -> torch.Tensor: + """ + Get the embeddings for the specified node indices (remotely). + This call must be called by all processes. + + Parameters + ---------- + idx : torch.Tensor + Index of the embeddings to collect. + Returns + ------- + torch.Tensor + The requested node embeddings. + """ + assert self._tensor is not None, "Please create WholeGraph tensor first." + idx = idx.cuda() + output_tensor = self._tensor.gather(idx) # output_tensor is on cuda by default + return output_tensor + + def get_local_tensor(self, host_view=False): + """ + Get the local embedding tensor and its element offset at current rank. + + Returns + ------- + (torch.Tensor, int) + Tuple of local torch Tensor (converted from DLPack) and its offset. + """ + local_tensor, offset = self._tensor.get_local_tensor(host_view = host_view) + return local_tensor + + def get_local_offset(self): + """ + Get the local embedding tensor and its element offset at current rank. + + Returns + ------- + (torch.Tensor, int) + Tuple of local torch Tensor (converted from DLPack) and its offset. + """ + _, offset = self._tensor.get_local_tensor() + return offset + + def get_comm(self): + """ + Get the communicator of the WholeGraph embedding. + + Returns + ------- + WholeMemoryCommunicator + The WholeGraph global communicator of the WholeGraph embedding. + """ + assert self._tensor is not None, "Please create WholeGraph tensor first." + return self._tensor.get_comm() + + @property + def dim(self): + return self._tensor.dim() + + @property + def shape(self): + return self._tensor.shape + + @property + def device(self): + return self._device + + @property + def dtype(self): + return self._dtype + + def __repr__(self): + if self._tensor is None: + return f"" + + # Format the output similar to PyTorch + tensor_repr = f"DistTensor(" + tensor_repr += f"shape={self._tensor.shape}, dtype={self._dtype}, device='{self._device}')" + return tensor_repr + + def __del__(self): + # Decrease instance count when an instance is deleted + DistTensor._instance_count -= 1 + if DistTensor._instance_count == 0: + finalize_wholegraph() + +class DistEmbedding(DistTensor): + """ + WholeGraph-backed Distributed Embedding Interface for PyTorch. + Parameters + ---------- + src: Optional[Union[torch.Tensor, str, List[str]]] + The source of the tensor. It can be a torch.Tensor on host, a file path, or a list of file paths. + When the source is omitted, the tensor will be load later. + shape : Optional[list, tuple] + The shape of the tensor. It has to be a one- or two-dimensional tensor for now. + When the shape is omitted, the `src` has to be specified and must be `pt` or `npy` file paths. + dtype : Optional[torch.dtype] + The dtype of the tensor. The data type has to be the one in the deep learning framework. + Whne the dtype is omitted, the `src` has to be specified and must be `pt` or `npy` file paths. + device : Optional[Literal["cpu", "cuda"]] = "cpu" + The desired location to store the embedding [ "cpu" | "cuda" ]. Default is "cpu", i.e., pinned-host memory. + partition_book : Union[List[int], None] = None + 1-D Range partition based on entry (dim-0). partition_book[i] determines the + entry count of rank i and shoud be a positive integer; the sum of partition_book should equal to shape[0]. + Entries will be equally partitioned if None. + backend : Optional[Literal["vmm", "nccl", "nvshmem", "chunked"]] = "nccl" + The backend used for communication. Default is "nccl". + cache_policy : Optional[WholeMemoryCachePolicy] = None + The cache policy for the tensor if it is an embedding. Default is None. + gather_sms : Optional[int] = -1 + Whether to gather the embeddings on all GPUs. Default is False. + round_robin_size: int = 0 + continuous embedding size of a rank using round robin shard strategy + name : Optional[str] + The name of the tensor. + """ + def __init__( + self, + src: Optional[Union[torch.Tensor, str, List[str]]] = None, + shape: Optional[Union[list, tuple]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Literal["cpu", "cuda"]] = "cpu", + partition_book: Union[List[int], None] = None, + backend: Optional[str] = "nccl", + cache_policy = None, #Optional[pylibwholegraph.WholeMemoryCachePolicy] = None, + gather_sms: Optional[int] = -1, + round_robin_size: int = 0, + name: Optional[str] = None, + ): + self._name = name + + super().__init__(src, shape, dtype, device, partition_book, backend, cache_policy=cache_policy, gather_sms=gather_sms, round_robin_size=round_robin_size) + self._embedding = self._tensor # returned _tensor is a WmEmbedding object + self._tensor = self._embedding.get_embedding_tensor() + + @classmethod + def from_tensor( + cls, + tensor: torch.Tensor, + device: Literal["cpu", "cuda"] = "cpu", + partition_book: Union[List[int], None] = None, + name: Optional[str] = None, + cache_policy = None, + *args, + **kwargs + ): + """ + Create a WholeGraph-backed Distributed Embedding (hooked with PyT's grad tracing) from a PyTorch tensor. + + Parameters + ---------- + tensor : torch.Tensor + The PyTorch tensor to be copied to the WholeGraph tensor. + device : str, optional + The desired location to store the embedding [ "cpu" | "cuda" ]. Default is "cpu". + name : str, optional + The name of the tensor. + + Returns + ------- + DistEmbedding + The WholeGraph-backed Distributed Tensor. + """ + return cls(tensor, device, partition_book, name, cache_policy, *args, **kwargs) + + @classmethod + def from_file( + cls, + file_path: str, + device: Literal["cpu", "cuda"] = "cpu", + partition_book: Union[List[int], None] = None, + name: Optional[str] = None, + cache_policy = None, + *args, + **kwargs + ): + """ + Create a WholeGraph-backed Distributed Tensor from a file. + + Parameters + ---------- + file_path : str + The file path to the tensor. The file can be in the format of PyTorch tensor or NumPy array. + device : str, optional + The desired location to store the embedding [ "cpu" | "cuda" ]. Default is "cpu". + name : str, optional + The name of the tensor. + + Returns + ------- + DistTensor + The WholeGraph-backed Distributed Tensor. + """ + return cls(file_path, device, partition_book, name, cache_policy, *args, **kwargs) + + + def __setitem__(self, idx: torch.Tensor, val: torch.Tensor): + """ + Set the embeddings for the specified node indices. + This call must be called by all processes. + + Parameters + ---------- + idx : torch.Tensor + Index of the embeddings to collect. + val : torch.Tensor + The requested node embeddings. + """ + assert self._tensor is not None, "Please create WholeGraph tensor first." + idx = idx.cuda() + val = val.cuda() + + if val.dtype != self.dtype: + val = val.to(self.dtype) + self._embedding.get_embedding_tensor().scatter(val, idx) + + def __getitem__(self, idx: torch.Tensor) -> torch.Tensor: + """ + Get the embeddings for the specified node indices (remotely). + This call must be called by all processes. + + Parameters + ---------- + idx : torch.Tensor + Index of the embeddings to collect. + Returns + ------- + torch.Tensor + The requested node embeddings. + """ + assert self._tensor is not None, "Please create WholeGraph tensor first." + idx = idx.cuda() + output_tensor = self._embedding.gather(idx) # output_tensor is on cuda by default + return output_tensor + + @property + def name(self): + return self._name + + def __repr__(self): + if self._embedding is None: + return f"" + + # Format the output similar to PyTorch + tensor_repr = f"DistEmbedding(" + if self._name: + tensor_repr += f"name={self._name}, " + tensor_repr += f"shape={self.shape}, dtype={self.dtype}, device='{self.device}')" + return tensor_repr + + def __del__(self): + super().__del__() \ No newline at end of file diff --git a/examples/distributed/wholegraph/nv_distributed_graph/wholegraph.py b/examples/distributed/wholegraph/nv_distributed_graph/wholegraph.py new file mode 100644 index 000000000000..6e44538dc150 --- /dev/null +++ b/examples/distributed/wholegraph/nv_distributed_graph/wholegraph.py @@ -0,0 +1,250 @@ +from typing import Tuple, List, Union + +import torch +import torch.distributed as dist + +import pylibwholegraph.torch as wgth +from . import dist_shmem + +_wm_global = False + +def init_wholegraph(): + global _wm_global + + if _wm_global is True: + return wgth.comm.get_global_communicator("nccl") + dist_shmem.init_process_group_per_node() + local_size = dist_shmem.get_local_size() + local_rank = dist_shmem.get_local_rank() + + wgth.init(dist.get_rank(), dist.get_world_size(), local_rank, local_size=local_size, wm_log_level="info") + print(f"[Rank {dist.get_rank()}] WholeGraph Initialization: " + f"{dist.get_world_size()} GPUs are used with {local_size} GPUs per node.") + global_comm = wgth.comm.get_global_communicator("nccl") + _wm_global = True + return global_comm + +def finalize_wholegraph(): + global _wm_global + if _wm_global is False: + return + wgth.finalize() + _wm_global = False + +def nvlink_network(): + r""" + Check if the current hardware supports cross-node NVLink network. + """ + if not _wm_global: + raise RuntimeError("WholeGraph is not initialized.") + + global_comm = wgth.comm.get_global_communicator("nccl") + local_size = dist_shmem.get_local_size() + world_size = dist.get_world_size() + + # Intra-node communication + if local_size == world_size: + # use WholeGraph to check if the current hardware supports direct p2p + return global_comm.support_type_location('continuous', 'cuda') + + # Check for multi-node support + is_cuda_supported = global_comm.support_type_location('continuous', 'cuda') + is_cpu_supported = global_comm.support_type_location('continuous', 'cpu') + + if is_cuda_supported and is_cpu_supported: + return True + + return False + +def copy_host_global_tensor_to_local( + wm_tensor, host_tensor, wm_comm +): + local_tensor, local_start = wm_tensor.get_local_tensor( + host_view=False + ) + ## enable these checks when the wholegraph is updated to 24.10 + #local_ref_start = wm_tensor.get_local_entry_start() + #local_ref_count = wm_tensor.get_local_entry_count() + #assert local_start == local_ref_start + #assert local_tensor.shape[0] == local_ref_count + local_tensor.copy_(host_tensor[local_start : local_start + local_tensor.shape[0]]) + wm_comm.barrier() + +def create_pyg_subgraph(WG_SampleOutput) -> Tuple: + # PyG_SampleOutput (node, row, col, edge, batch...): + # node (torch.Tensor): The sampled nodes in the original graph. + # row (torch.Tensor): The source node indices of the sampled subgraph. + # Indices must be within {0, ..., num_nodes - 1} where num_nodes is the number of nodes in sampled graph. + # col (torch.Tensor): The destination node indices of the sampled subgraph. Indices must be within {0, ..., num_nodes - 1} + # edge (torch.Tensor, optional): The sampled edges in the original graph. (for obtain edge features from the original graph) + # batch (torch.Tensor, optional): The vector to identify the seed node for each sampled node in case of disjoint subgraph + # sampling per seed node. (None) + # num_sampled_nodes (List[int], optional): The number of sampled nodes per hop. + # num_sampled_edges (List[int], optional): The number of sampled edges per hop. + sampled_nodes_list, edge_indice_list, csr_row_ptr_list, csr_col_ind_list = WG_SampleOutput + num_sampled_nodes = [] + node = sampled_nodes_list[0] + + for hop in range(len(sampled_nodes_list)-1): + sampled_nodes = len(sampled_nodes_list[hop]) - len(sampled_nodes_list[hop+1]) + num_sampled_nodes.append(sampled_nodes) + num_sampled_nodes.append(len(sampled_nodes_list[-1])) + num_sampled_nodes.reverse() + + layers = len(edge_indice_list) + num_sampled_edges = [len(csr_col_ind_list[-1])] + # Loop in reverse order, starting from the second last layer + for layer in range(layers - 2, -1, -1): + num_sampled_edges.append(len(csr_col_ind_list[layer] - len(csr_col_ind_list[layer + 1]))) + + row = csr_col_ind_list[0] # rows + col = edge_indice_list[0][1] # dst node + + edge = None + batch = None + return node, row, col, edge, batch, num_sampled_nodes, num_sampled_edges + +def sample_nodes_wmb_fn(wg_sampler, seeds, fanouts): + # WG_SampleOutput (target_gids, edge_indice, csr_row_ptr, csr_col_ind): + # target_gids [1D tensors]: unique sampled global node ids for each layer + # edge_indice [rank-2 tensors]: edge list [src, des] for each layer, in local node id (start from 0 for each layer) to confirm + # csr_row_ptr [1D tensors]: csr row ptrs for each subgraph in each layer (starting from 0 for each subgraph) + # csr_col_ind [1D tensors]: csr col indx for each subgraph in each layer + WG_SampleOutput = wg_sampler.multilayer_sample_without_replacement(seeds, fanouts, None) + return create_pyg_subgraph(WG_SampleOutput) + +def create_wg_dist_tensor( + shape: list, + dtype: torch.dtype, + location: str = "cpu", + partition_book: Union[List[int], None] = None, # default is even partition + backend: str = "nccl", # default is nccl; support nccl, vmm, nvshmem... + **kwargs +): + """Create a WholeGraph-managed distributed tensor. + + Parameters + ---------- + shape : list + The shape of the tensor. It has to be a two-dimensional and one-dimensional tensor for now. + The first dimension typically is the number of nodes. + The second dimension is the feature/embedding dimension. + dtype : torch.dtype + The dtype of the tensor. The data type has to be the one in the deep learning framework. + location : str, optional + The desired location to store the embedding [ "cpu" | "cuda" ] + partition_book : list, optional + The partition book for the embedding tensor. The length of the partition book should be the same as the number of ranks. + backend : str, optional + The backend for the distributed tensor [ "nccl" | "vmm" | "nvshmem" ] (nvshmem not turned on in this example) + """ + global_comm = init_wholegraph() + + if backend == "nccl": + embedding_wholememory_type = "distributed" + elif backend == "vmm": + embedding_wholememory_type = "continuous" + elif backend == "nvshmem": + raise NotImplementedError("NVSHMEM backend has not turned on yet.") + else: + raise ValueError(f"Unsupported backend: {backend}") + embedding_wholememory_location = location + + if "cache_policy" in kwargs: + assert len(shape) == 2, "The shape of the embedding tensor must be 2D." + cache_policy = kwargs['cache_policy'] + kwargs.pop('cache_policy') + + wm_embedding = wgth.create_embedding( + global_comm, + embedding_wholememory_type, + embedding_wholememory_location, + dtype, + shape, + cache_policy=cache_policy, # disable cache for now + #embedding_entry_partition=partition_book, + **kwargs + #tensor_entry_partition=None # important to do load balance + ) + else: + assert len(shape) == 2 or len(shape) == 1, "The shape of the tensor must be 2D or 1D." + wm_embedding = wgth.create_wholememory_tensor( + global_comm, + embedding_wholememory_type, + embedding_wholememory_location, + shape, + dtype, + strides=None, + #tensor_entry_partition=partition_book # important to do load balance + ) + return wm_embedding + +def create_wg_dist_tensor_from_files( + file_list: List[str], + shape: list, + dtype: torch.dtype, + location: str = "cpu", + partition_book: Union[List[int], None] = None, # default is even partition + backend: str = "nccl", # default is nccl; support nccl, vmm, nvshmem... + **kwargs +): + """Create a WholeGraph-managed distributed tensor from a list of files. + + Parameters + ---------- + file_list : list + The list of files to load the embedding tensor. + shape : list + The shape of the tensor. It has to be a two-dimensional and one-dimensional tensor for now. + The first dimension typically is the number of nodes. + The second dimension is the feature/embedding dimension. + dtype : torch.dtype + The dtype of the tensor. The data type has to be the one in the deep learning framework. + location : str, optional + The desired location to store the embedding [ "cpu" | "cuda" ] + partition_book : list, optional + The partition book for the embedding tensor. The length of the partition book should be the same as the number of ranks. + backend : str, optional + The backend for the distributed tensor [ "nccl" | "vmm" | "nvshmem" ] (nvshmem not turned on in this example) + """ + global_comm = init_wholegraph() + + if backend == "nccl": + embedding_wholememory_type = "distributed" + elif backend == "vmm": + embedding_wholememory_type = "continuous" + elif backend == "nvshmem": + raise NotImplementedError("NVSHMEM backend has not turned on yet.") + else: + raise ValueError(f"Unsupported backend: {backend}") + embedding_wholememory_location = location + + if "cache_policy" in kwargs: + assert len(shape) == 2, "The shape of the embedding tensor must be 2D." + cache_policy = kwargs['cache_policy'] + kwargs.pop('cache_policy') + + wm_embedding = wgth.create_embedding_from_filelist( + global_comm, + embedding_wholememory_type, + embedding_wholememory_location, + file_list, + dtype, + shape[1], + cache_policy=cache_policy, # disable cache for now + #embedding_entry_partition=partition_book, + **kwargs + ) + else: + assert len(shape) == 2 or len(shape) == 1, "The shape of the tensor must be 2D or 1D." + last_dim_size = 0 if len(shape) == 1 else shape[1] + wm_embedding = wgth.create_wholememory_tensor_from_filelist( + global_comm, + embedding_wholememory_type, + embedding_wholememory_location, + file_list, + dtype, + last_dim_size, + #tensor_entry_partition=partition_book # important to do load balance + ) + return wm_embedding \ No newline at end of file diff --git a/examples/distributed/wholegraph/papers100m_dist_wholegraph_nc.py b/examples/distributed/wholegraph/papers100m_dist_wholegraph_nc.py new file mode 100644 index 000000000000..9de67785c289 --- /dev/null +++ b/examples/distributed/wholegraph/papers100m_dist_wholegraph_nc.py @@ -0,0 +1,194 @@ +"""Multi-node multi-GPU example on ogbn-papers100m. + +Example way to run using srun: +srun -l -N --ntasks-per-node= \ +--container-name=cont --container-image= \ +--container-mounts=/ogb-papers100m/:/workspace/dataset +python3 path_to_script.py +""" +import os +import time +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from ogb.nodeproppred import PygNodePropPredDataset +from torch.nn.parallel import DistributedDataParallel +from torchmetrics import Accuracy + +from torch_geometric.loader import NodeLoader +from torch_geometric.nn import GCN + +import torch_geometric.transforms as T +from torch_geometric.sampler import BaseSampler + +from nv_distributed_graph import dist_shmem +from feature_store import WholeGraphFeatureStore +from graph_store import WholeGraphGraphStore + + +class WholeGraphSampler(BaseSampler): + r""" + A naive sampler class for WholeGraph graph storage that only supports uniform node-based sampling on homogeneous graph. + """ + from torch_geometric.sampler import SamplerOutput, NodeSamplerInput + def __init__( + self, + graph: WholeGraphGraphStore, + num_neighbors, + ): + import pylibwholegraph.torch as wgth + + self.num_neighbors = num_neighbors + self.wg_sampler = wgth.GraphStructure() + row_indx, col_ptrs, _ = graph.csc() + self.wg_sampler.set_csr_graph(col_ptrs._tensor, row_indx._tensor) + + def sample_from_nodes( + self, + inputs: NodeSamplerInput + ) -> SamplerOutput: + r""" + Sample subgraphs from the given nodes based on uniform node-based sampling. + """ + seed = inputs.node.cuda(non_blocking=True) # WholeGraph Sampler needs all seeds on device + WG_SampleOutput = self.wg_sampler.multilayer_sample_without_replacement(seed, self.num_neighbors, None) + out = WholeGraphGraphStore.create_pyg_subgraph(WG_SampleOutput) + out.metadata = (inputs.input_id, inputs.time) + return out + +def run(world_size, rank, local_rank, device): + wall_clock_start = time.perf_counter() + + # Will query the runtime environment for `MASTER_ADDR` and `MASTER_PORT`. + # Make sure, those are set! + dist.init_process_group('nccl', world_size=world_size, rank=rank) + dist_shmem.init_process_group_per_node() + + transform = T.Compose([T.ToUndirected(), T.ToSparseTensor()]) + + # Load the dataset in the local root process and share it with others + if dist_shmem.get_local_rank() == 0: + # Use pre_transform to avoid on-fly graph format conversion and reduce RAM usage during runtime, especially for large graphs. + dataset = PygNodePropPredDataset(name='ogbn-papers100M', root='/workspace', pre_transform=transform) + else: + dataset = None + dataset = dist_shmem.to_shmem(dataset) # move dataset to shmem for local ranks access + + split_idx = dataset.get_idx_split() + split_idx['train'] = split_idx['train'].split( + split_idx['train'].size(0) // world_size, dim=0)[rank].clone() + split_idx['valid'] = split_idx['valid'].split( + split_idx['valid'].size(0) // world_size, dim=0)[rank].clone() + split_idx['test'] = split_idx['test'].split( + split_idx['test'].size(0) // world_size, dim=0)[rank].clone() + num_features = dataset.num_features + num_classes = dataset.num_classes + + data = dataset[0] + feature_store = WholeGraphFeatureStore(pyg_data=data) + graph_store = WholeGraphGraphStore(pyg_data=data) + + kwargs = dict( + data=(feature_store, graph_store), + batch_size=1024, + num_workers=0, # with wholegraph graph store you don't need workers + filter_per_worker=False, # WholeGraph feature fetching is not fork-safe + ) + + node_sampler = WholeGraphSampler( + graph_store, + num_neighbors=[30, 30], + ) + + train_loader = NodeLoader( + input_nodes=split_idx['train'], + node_sampler=node_sampler, + shuffle=True, + drop_last=True, + **kwargs, + ) + val_loader = NodeLoader(input_nodes=split_idx['valid'], node_sampler=node_sampler, **kwargs) + test_loader = NodeLoader(input_nodes=split_idx['test'], node_sampler=node_sampler, **kwargs) + + eval_steps = 1000 + model = GCN(num_features, 256, 2, num_classes) + acc = Accuracy(task="multiclass", num_classes=num_classes).to(device) + model = DistributedDataParallel(model.to(device), device_ids=[local_rank]) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001, + weight_decay=5e-4) + + if rank == 0: + prep_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total time before training begins (prep_time)=", prep_time, + "seconds") + print("Beginning training...") + + for epoch in range(1, 21): + dist.barrier() + start = time.time() + model.train() + for i, batch in enumerate(train_loader): + batch = batch.to(device) + optimizer.zero_grad() + y = batch.y[:batch.batch_size].view(-1).to(torch.long) + out = model(batch.x, batch.edge_index)[:batch.batch_size] + loss = F.cross_entropy(out, y) + loss.backward() + optimizer.step() + + if rank == 0 and i % 100 == 0: + print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}') + + # Profile run: + # We synchronize before barrier to flush GPU OPs first, + # then adding barrier to sync CPUs to find max train time among all ranks. + torch.cuda.synchronize() + dist.barrier() + epoch_end = time.time() + + @torch.no_grad() + def test(loader: NodeLoader, num_steps: Optional[int] = None): + model.eval() + for j, batch in enumerate(loader): + if num_steps is not None and j >= num_steps: + break + batch = batch.to(device) + out = model(batch.x, batch.edge_index)[:batch.batch_size] + y = batch.y[:batch.batch_size].view(-1).to(torch.long) + acc(out, y) + acc_sum = acc.compute() + return acc_sum + + eval_acc = test(val_loader, num_steps=eval_steps) + if rank == 0: + print(f"Val Accuracy: {eval_acc:.4f}%", ) + print( + f"Epoch {epoch:05d} | " + f"Accuracy {eval_acc:.4f} | " + f"Time {epoch_end - start:.2f}" + ) + + acc.reset() + dist.barrier() + + test_acc = test(test_loader) + if rank == 0: + print(f"Test Accuracy: {test_acc:.4f}%", ) + dist.destroy_process_group() if dist.is_initialized() else None + +if __name__ == '__main__': + # Get the world size from the WORLD_SIZE variable or directly from SLURM: + world_size = int( + os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS'))) + # Likewise for RANK and LOCAL_RANK: + rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID'))) + local_rank = int( + os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID'))) + + assert torch.cuda.is_available() + device = torch.device(local_rank) + torch.cuda.set_device(device) + run(world_size, rank, local_rank, device) + From 9f170b0e4563e3fc63fd85e48324309269394157 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Oct 2024 23:06:00 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/distributed/wholegraph/README | 2 +- .../distributed/wholegraph/benchmark_data.py | 66 ++++--- .../distributed/wholegraph/feature_store.py | 70 ++++--- .../distributed/wholegraph/graph_store.py | 52 ++--- .../nv_distributed_graph/__init__.py | 2 +- .../nv_distributed_graph/dist_graph.py | 52 ++--- .../nv_distributed_graph/dist_shmem.py | 26 ++- .../nv_distributed_graph/dist_tensor.py | 187 ++++++++++-------- .../nv_distributed_graph/wholegraph.py | 93 +++++---- .../papers100m_dist_wholegraph_nc.py | 60 +++--- 10 files changed, 343 insertions(+), 267 deletions(-) diff --git a/examples/distributed/wholegraph/README b/examples/distributed/wholegraph/README index f10147c64b05..6883b06df9e0 100644 --- a/examples/distributed/wholegraph/README +++ b/examples/distributed/wholegraph/README @@ -59,4 +59,4 @@ torchrun --nnodes 1 --nproc-per-node benchmark_data.py --mode UV ### WholeGraph FeatureStore + GraphStore (UVA for feature and graph store access) ``` torchrun --nnodes 1 --nproc-per-node benchmark_data.py --mode UVA -``` \ No newline at end of file +``` diff --git a/examples/distributed/wholegraph/benchmark_data.py b/examples/distributed/wholegraph/benchmark_data.py index ee5fa5dbe94a..2868668b8b31 100644 --- a/examples/distributed/wholegraph/benchmark_data.py +++ b/examples/distributed/wholegraph/benchmark_data.py @@ -6,32 +6,31 @@ --container-mounts=/ogb-papers100m/:/workspace/dataset python3 path_to_script.py """ +import argparse import os import time from typing import Optional -import argparse import torch import torch.distributed as dist import torch.nn.functional as F +from feature_store import WholeGraphFeatureStore +from graph_store import WholeGraphGraphStore +from nv_distributed_graph import dist_shmem from ogb.nodeproppred import PygNodePropPredDataset from torch.nn.parallel import DistributedDataParallel from torchmetrics import Accuracy -from torch_geometric.loader import NodeLoader, NeighborLoader +from torch_geometric.loader import NeighborLoader, NodeLoader from torch_geometric.nn import GCN - from torch_geometric.sampler import BaseSampler -from nv_distributed_graph import dist_shmem -from feature_store import WholeGraphFeatureStore -from graph_store import WholeGraphGraphStore class WholeGraphSampler(BaseSampler): - r""" - A naive sampler class for WholeGraph graph storage that only supports uniform node-based sampling on homogeneous graph. + r"""A naive sampler class for WholeGraph graph storage that only supports uniform node-based sampling on homogeneous graph. """ - from torch_geometric.sampler import SamplerOutput, NodeSamplerInput + from torch_geometric.sampler import NodeSamplerInput, SamplerOutput + def __init__( self, graph: WholeGraphGraphStore, @@ -44,19 +43,18 @@ def __init__( row_indx, col_ptrs, _ = graph.csc() self.wg_sampler.set_csr_graph(col_ptrs._tensor, row_indx._tensor) - def sample_from_nodes( - self, - inputs: NodeSamplerInput - ) -> SamplerOutput: - r""" - Sample subgraphs from the given nodes based on uniform node-based sampling. + def sample_from_nodes(self, inputs: NodeSamplerInput) -> SamplerOutput: + r"""Sample subgraphs from the given nodes based on uniform node-based sampling. """ - seed = inputs.node.cuda(non_blocking=True) # WholeGraph Sampler needs all seeds on device - WG_SampleOutput = self.wg_sampler.multilayer_sample_without_replacement(seed, self.num_neighbors, None) + seed = inputs.node.cuda( + non_blocking=True) # WholeGraph Sampler needs all seeds on device + WG_SampleOutput = self.wg_sampler.multilayer_sample_without_replacement( + seed, self.num_neighbors, None) out = WholeGraphGraphStore.create_pyg_subgraph(WG_SampleOutput) out.metadata = (inputs.input_id, inputs.time) return out + def run(world_size, rank, local_rank, device, mode): wall_clock_start = time.perf_counter() @@ -67,10 +65,11 @@ def run(world_size, rank, local_rank, device, mode): # Load the dataset in the local root process and share it with local ranks if dist_shmem.get_local_rank() == 0: - dataset = PygNodePropPredDataset(name='ogbn-products', root='/workspace') + dataset = PygNodePropPredDataset(name='ogbn-products', + root='/workspace') else: dataset = None - dataset = dist_shmem.to_shmem(dataset) # move dataset to shmem + dataset = dist_shmem.to_shmem(dataset) # move dataset to shmem split_idx = dataset.get_idx_split() split_idx['train'] = split_idx['train'].split( @@ -109,7 +108,8 @@ def run(world_size, rank, local_rank, device, mode): batch_size=1024, num_neighbors=[30, 30], num_workers=4, - filter_per_worker=False, # WholeGraph feature fetching is not fork-safe + filter_per_worker= + False, # WholeGraph feature fetching is not fork-safe ) train_loader = NeighborLoader( input_nodes=split_idx['train'], @@ -127,8 +127,9 @@ def run(world_size, rank, local_rank, device, mode): kwargs = dict( data=data, batch_size=1024, - num_workers=0, # with wholegraph sampler you don't need workers - filter_per_worker=False, # WholeGraph feature fetching is not fork-safe + num_workers=0, # with wholegraph sampler you don't need workers + filter_per_worker= + False, # WholeGraph feature fetching is not fork-safe ) node_sampler = WholeGraphSampler( graph_store, @@ -141,15 +142,17 @@ def run(world_size, rank, local_rank, device, mode): drop_last=True, **kwargs, ) - val_loader = NodeLoader(input_nodes=split_idx['valid'], node_sampler=node_sampler, **kwargs) - test_loader = NodeLoader(input_nodes=split_idx['test'], node_sampler=node_sampler, **kwargs) + val_loader = NodeLoader(input_nodes=split_idx['valid'], + node_sampler=node_sampler, **kwargs) + test_loader = NodeLoader(input_nodes=split_idx['test'], + node_sampler=node_sampler, **kwargs) eval_steps = 1000 model = GCN(num_features, 256, 2, num_classes) acc = Accuracy(task="multiclass", num_classes=num_classes).to(device) model = DistributedDataParallel(model.to(device), device_ids=[local_rank]) optimizer = torch.optim.Adam(model.parameters(), lr=0.001, - weight_decay=5e-4) + weight_decay=5e-4) if rank == 0: prep_time = round(time.perf_counter() - wall_clock_start, 2) @@ -195,11 +198,9 @@ def test(loader: NodeLoader, num_steps: Optional[int] = None): eval_acc = test(val_loader, num_steps=eval_steps) if rank == 0: print(f"Val Accuracy: {eval_acc:.4f}%", ) - print( - f"Epoch {epoch:05d} | " - f"Accuracy {eval_acc:.4f} | " - f"Time {epoch_end - start:.2f}" - ) + print(f"Epoch {epoch:05d} | " + f"Accuracy {eval_acc:.4f} | " + f"Time {epoch_end - start:.2f}") acc.reset() dist.barrier() @@ -209,9 +210,11 @@ def test(loader: NodeLoader, num_steps: Optional[int] = None): print(f"Test Accuracy: {test_acc:.4f}%", ) dist.destroy_process_group() if dist.is_initialized() else None + if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--mode', type=str, default='baseline', choices=['baseline', 'UVA-features', 'UVA']) + parser.add_argument('--mode', type=str, default='baseline', + choices=['baseline', 'UVA-features', 'UVA']) args = parser.parse_args() # Get the world size from the WORLD_SIZE variable or directly from SLURM: @@ -226,4 +229,3 @@ def test(loader: NodeLoader, num_steps: Optional[int] = None): device = torch.device(local_rank) torch.cuda.set_device(device) run(world_size, rank, local_rank, device, args.mode) - diff --git a/examples/distributed/wholegraph/feature_store.py b/examples/distributed/wholegraph/feature_store.py index e9778d21d041..c2d7ee9efd44 100644 --- a/examples/distributed/wholegraph/feature_store.py +++ b/examples/distributed/wholegraph/feature_store.py @@ -2,14 +2,19 @@ import torch import torch.distributed as dist +from nv_distributed_graph import ( + DistEmbedding, + DistTensor, + dist_shmem, + nvlink_network, +) import torch_geometric -from nv_distributed_graph import DistTensor, DistEmbedding, dist_shmem, nvlink_network - from torch_geometric.data.feature_store import FeatureStore, TensorAttr + class WholeGraphFeatureStore(FeatureStore): - r""" A high-performance, UVA-enabled, and multi-GPU/multi-node friendly feature store, powered by WholeGraph library. + r"""A high-performance, UVA-enabled, and multi-GPU/multi-node friendly feature store, powered by WholeGraph library. It is compatible with PyG's FeatureStore class and supports both homogeneous and heterogeneous graph data types. Args: @@ -34,7 +39,8 @@ class WholeGraphFeatureStore(FeatureStore): def __init__(self, pyg_data): r"""Initializes the WholeGraphFeatureStore class and loads features from torch_geometric.data.Data/HeteroData.""" super().__init__() - self._store = {} # A dictionary of tuple to hold the feature embeddings + self._store = { + } # A dictionary of tuple to hold the feature embeddings if dist_shmem.get_local_rank() == dist.get_rank(): self.backend = 'vmm' @@ -42,34 +48,46 @@ def __init__(self, pyg_data): self.backend = 'vmm' if nvlink_network() else 'nccl' if isinstance(pyg_data, torch_geometric.data.Data): - self.put_tensor(pyg_data['x'], group_name=None, attr_name='x', index=None) - self.put_tensor(pyg_data['y'], group_name=None, attr_name='y', index=None) + self.put_tensor(pyg_data['x'], group_name=None, attr_name='x', + index=None) + self.put_tensor(pyg_data['y'], group_name=None, attr_name='y', + index=None) - elif isinstance(pyg_data, torch_geometric.data.HeteroData): # if HeteroData, we need to handle differently + elif isinstance(pyg_data, torch_geometric.data.HeteroData + ): # if HeteroData, we need to handle differently for group_name, group in pyg_data.node_items(): for attr_name in group: - if group.is_node_attr(attr_name) and attr_name in {'x', 'y'}: - self.put_tensor(pyg_data[group_name][attr_name], group_name=group_name, attr_name=attr_name, index=None) + if group.is_node_attr(attr_name) and attr_name in { + 'x', 'y' + }: + self.put_tensor(pyg_data[group_name][attr_name], + group_name=group_name, + attr_name=attr_name, index=None) # This is a hack for MAG240M dataset, to add node features for 'institution' and 'author' nodes. # This should not be presented in the upstream code. elif attr_name == 'num_nodes': feature_dim = 768 num_nodes = group[attr_name] - shape=[num_nodes, feature_dim] - self[group_name, 'x', None] = DistEmbedding(shape=shape, dtype=torch.float16, device="cpu", backend=self.backend) + shape = [num_nodes, feature_dim] + self[group_name, 'x', + None] = DistEmbedding(shape=shape, + dtype=torch.float16, + device="cpu", + backend=self.backend) else: - raise TypeError("Expected pyg_data to be of type torch_geometric.data.Data or torch_geometric.data.HeteroData.") + raise TypeError( + "Expected pyg_data to be of type torch_geometric.data.Data or torch_geometric.data.HeteroData." + ) def _put_tensor(self, tensor: torch.Tensor, attr): - """ - Creates and stores features (either DistTensor or DistEmbedding) from the given tensor, + """Creates and stores features (either DistTensor or DistEmbedding) from the given tensor, using a key derived from the group and attribute name. Args: tensor (torch.Tensor): The tensor to be passed to the feature store. attr: PyG's TensorAttr to fully specify each feature store. """ - key = (attr.group_name, attr.attr_name) + key = (attr.group_name, attr.attr_name) out = self._store.get(key) if out is not None and attr.index is not None: out[attr.index] = tensor @@ -77,14 +95,18 @@ def _put_tensor(self, tensor: torch.Tensor, attr): assert attr.index is None if tensor.dim() == 1: # No need to unsqueeze if WholeGraph fix this https://github.com/rapidsai/wholegraph/pull/229 - self._store[key] = DistTensor(tensor.unsqueeze(1), device="cpu", backend=self.backend) + self._store[key] = DistTensor(tensor.unsqueeze(1), + device="cpu", + backend=self.backend) else: - self._store[key] = DistEmbedding(tensor, device="cpu", backend=self.backend) + self._store[key] = DistEmbedding(tensor, device="cpu", + backend=self.backend) return True - def _get_tensor(self, attr) -> Optional[Union[torch.Tensor, DistTensor, DistEmbedding]]: - """ - Retrieves a tensor based on the provided attribute. + def _get_tensor( + self, + attr) -> Optional[Union[torch.Tensor, DistTensor, DistEmbedding]]: + """Retrieves a tensor based on the provided attribute. Args: attr: An object containing the necessary attributes to fetch the tensor. @@ -92,8 +114,7 @@ def _get_tensor(self, attr) -> Optional[Union[torch.Tensor, DistTensor, DistEmbe Returns: A tensor which can be of type torch.Tensor, DistTensor, or DistEmbedding, or None if not found. """ - - key = (attr.group_name, attr.attr_name) + key = (attr.group_name, attr.attr_name) tensor = self._store.get(key) if tensor is not None: if attr.index is not None: @@ -112,5 +133,6 @@ def _get_tensor_size(self, attr): def get_all_tensor_attrs(self): r"""Obtains all feature attributes stored in `Data`.""" return [ - TensorAttr(group_name=group, attr_name=name) for group, name in self._store.keys() - ] \ No newline at end of file + TensorAttr(group_name=group, attr_name=name) + for group, name in self._store.keys() + ] diff --git a/examples/distributed/wholegraph/graph_store.py b/examples/distributed/wholegraph/graph_store.py index ac11c7d304c9..de2cbe11dba0 100644 --- a/examples/distributed/wholegraph/graph_store.py +++ b/examples/distributed/wholegraph/graph_store.py @@ -1,30 +1,31 @@ -from typing import List, Optional, Tuple from dataclasses import dataclass +from typing import List, Optional, Tuple import torch.distributed as dist +from nv_distributed_graph import DistGraphCSC, dist_shmem, nvlink_network import torch_geometric -from torch_geometric.data.graph_store import GraphStore, EdgeAttr, EdgeLayout +from torch_geometric.data.graph_store import EdgeAttr, EdgeLayout, GraphStore from torch_geometric.sampler import SamplerOutput +from torch_geometric.typing import EdgeType -from torch_geometric.typing import EdgeType - -from nv_distributed_graph import DistGraphCSC, dist_shmem, nvlink_network @dataclass class WholeGraphEdgeAttr(EdgeAttr): r"""Edge attribute class for WholeGraph GraphStore enforcing layout to be CSC.""" def __init__( self, - edge_type: Optional[EdgeType] = None, # use string to represent edge type for simplicity + edge_type: Optional[ + EdgeType] = None, # use string to represent edge type for simplicity is_sorted: bool = False, size: Optional[Tuple[int, int]] = None, ): - layout = EdgeLayout.CSC # Enforce CSC layout for WholeGraph for now + layout = EdgeLayout.CSC # Enforce CSC layout for WholeGraph for now super().__init__(edge_type, layout, is_sorted, size) + class WholeGraphGraphStore(GraphStore): - r""" A high-performance, UVA-enabled, and multi-GPU/multi-node friendly graph store, powered by WholeGraph library. + r"""A high-performance, UVA-enabled, and multi-GPU/multi-node friendly graph store, powered by WholeGraph library. It is compatible with PyG's GraphStore base class and supports both homogeneous and heterogeneous graph data types. Args: @@ -35,7 +36,8 @@ class WholeGraphGraphStore(GraphStore): """ def __init__(self, pyg_data, format='wholegraph'): super().__init__(edge_attr_cls=WholeGraphEdgeAttr) - self._g = {} # for simplicy, _g is a dictionary of DistGraphCSC to hold the graph structure data for each type + self._g = { + } # for simplicy, _g is a dictionary of DistGraphCSC to hold the graph structure data for each type if format == 'wholegraph': pinned_shared = False @@ -45,7 +47,7 @@ def __init__(self, pyg_data, format='wholegraph'): backend = 'vmm' if nvlink_network() else 'nccl' elif format == 'pyg': pinned_shared = True - backend = None # backend is a no-op for pyg format + backend = None # backend is a no-op for pyg format else: raise ValueError("Unsupported underlying graph format") @@ -55,7 +57,7 @@ def __init__(self, pyg_data, format='wholegraph'): if 'adj_t' not in pyg_data: row, col = None, None if dist_shmem.get_local_rank() == 0: - row, col, _ = pyg_data.csc() # discard permutation for now + row, col, _ = pyg_data.csc() # discard permutation for now row = dist_shmem.to_shmem(row) col = dist_shmem.to_shmem(col) size = pyg_data.size() @@ -75,7 +77,8 @@ def __init__(self, pyg_data, format='wholegraph'): ) self.put_adj_t(graph, size=size) - elif isinstance(pyg_data, torch_geometric.data.HeteroData): # hetero graph + elif isinstance(pyg_data, + torch_geometric.data.HeteroData): # hetero graph # issue: this will crash: pyg_data.get_all_edge_attrs()[0] if pyg_data is a torch sparse csr # walkaround: self.num_nodes = pyg_data.num_nodes @@ -83,7 +86,8 @@ def __init__(self, pyg_data, format='wholegraph'): if 'adj_t' not in edge_store: row, col = None, None if dist_shmem.get_local_rank() == 0: - row, col, _ = edge_store.csc() # discard permutation for now + row, col, _ = edge_store.csc( + ) # discard permutation for now row = dist_shmem.to_shmem(row) col = dist_shmem.to_shmem(col) size = edge_store.size() @@ -123,7 +127,8 @@ def get_adj_t(self, *args, **kwargs) -> DistGraphCSC: raise KeyError(f"'adj_t' for '{edge_attr}' not found") return graph_adj_t - def _put_adj_t(self, adj_t: DistGraphCSC, edge_attr: WholeGraphEdgeAttr) -> bool: + def _put_adj_t(self, adj_t: DistGraphCSC, + edge_attr: WholeGraphEdgeAttr) -> bool: if not hasattr(self, '_edge_attrs'): self._edge_attrs = {} self._edge_attrs[edge_attr.edge_type] = edge_attr @@ -135,7 +140,8 @@ def _put_adj_t(self, adj_t: DistGraphCSC, edge_attr: WholeGraphEdgeAttr) -> bool edge_attr.size = adj_t.size return True - def _get_adj_t(self, edge_attr: WholeGraphEdgeAttr) -> Optional[DistGraphCSC]: + def _get_adj_t(self, + edge_attr: WholeGraphEdgeAttr) -> Optional[DistGraphCSC]: store = self._g.get(edge_attr.edge_type) edge_attrs = getattr(self, '_edge_attrs', {}) edge_attr = edge_attrs[edge_attr.edge_type] @@ -164,8 +170,7 @@ def get_all_edge_attrs(self) -> List[WholeGraphEdgeAttr]: edge_attrs = getattr(self, '_edge_attrs', {}) for key, store in self._g.items(): if key not in edge_attrs: - edge_attrs[key] = WholeGraphEdgeAttr( - key, size=store.size) + edge_attrs[key] = WholeGraphEdgeAttr(key, size=store.size) return list(edge_attrs.values()) def csc(self): @@ -173,7 +178,7 @@ def csc(self): if not self.is_hetero: key = self.get_all_edge_attrs()[0] store = self._get_adj_t(key) - return store.row_indx, store.col_ptrs, None # no permutation vector + return store.row_indx, store.col_ptrs, None # no permutation vector else: row_dict = {} col_dict = {} @@ -205,8 +210,9 @@ def create_pyg_subgraph(WG_SampleOutput) -> Tuple: num_sampled_nodes = [] node = sampled_nodes_list[0] - for hop in range(len(sampled_nodes_list)-1): - sampled_nodes = len(sampled_nodes_list[hop]) - len(sampled_nodes_list[hop+1]) + for hop in range(len(sampled_nodes_list) - 1): + sampled_nodes = len(sampled_nodes_list[hop]) - len( + sampled_nodes_list[hop + 1]) num_sampled_nodes.append(sampled_nodes) num_sampled_nodes.append(len(sampled_nodes_list[-1])) num_sampled_nodes.reverse() @@ -215,10 +221,12 @@ def create_pyg_subgraph(WG_SampleOutput) -> Tuple: num_sampled_edges = [len(csr_col_ind_list[-1])] # Loop in reverse order, starting from the second last layer for layer in range(layers - 2, -1, -1): - num_sampled_edges.append(len(csr_col_ind_list[layer] - len(csr_col_ind_list[layer + 1]))) + num_sampled_edges.append( + len(csr_col_ind_list[layer] - + len(csr_col_ind_list[layer + 1]))) row = csr_col_ind_list[0] # rows - col = edge_indice_list[0][1] # dst node + col = edge_indice_list[0][1] # dst node edge = None batch = None diff --git a/examples/distributed/wholegraph/nv_distributed_graph/__init__.py b/examples/distributed/wholegraph/nv_distributed_graph/__init__.py index 6bf728997b0e..b75920a2dc23 100644 --- a/examples/distributed/wholegraph/nv_distributed_graph/__init__.py +++ b/examples/distributed/wholegraph/nv_distributed_graph/__init__.py @@ -1,4 +1,4 @@ from .dist_graph import DistGraphCSC from .dist_tensor import DistTensor, DistEmbedding from .dist_shmem import init_process_group_per_node, get_local_process_group, get_local_root, get_local_rank, get_local_size, to_shmem -from .wholegraph import nvlink_network \ No newline at end of file +from .wholegraph import nvlink_network diff --git a/examples/distributed/wholegraph/nv_distributed_graph/dist_graph.py b/examples/distributed/wholegraph/nv_distributed_graph/dist_graph.py index 65f4e0314362..604defcb2072 100644 --- a/examples/distributed/wholegraph/nv_distributed_graph/dist_graph.py +++ b/examples/distributed/wholegraph/nv_distributed_graph/dist_graph.py @@ -1,13 +1,13 @@ -from typing import Any, List, Union, Literal, Optional -import numpy as np +from typing import Any, List, Literal, Optional, Union +import numpy as np import torch -from . import dist_shmem -from . import dist_tensor +from . import dist_shmem, dist_tensor + class DistGraphCSC: - """ Distributed Graph Store based on DistTensors for Compressed Sparse Column (CSC) format. + """Distributed Graph Store based on DistTensors for Compressed Sparse Column (CSC) format. Only support homogeneous graph for now. Parameters ---------- @@ -18,19 +18,22 @@ class DistGraphCSC: def __init__( self, col_ptrs_src: Optional[Union[torch.Tensor, str, List[str]]] = None, - row_indx_src : Optional[Union[torch.Tensor, str, List[str]]] = None, + row_indx_src: Optional[Union[torch.Tensor, str, List[str]]] = None, device: Optional[Literal["cpu", "cuda"]] = "cpu", pinned_shared: Optional[bool] = False, - partition_book: Optional[Union[List[int], None]] = None, # location memtype ?? backend?? ; engine; comm = vmm/nccl .. - backend: Optional[str] = "nccl", # reserved this for future use + partition_book: Optional[Union[ + List[int], + None]] = None, # location memtype ?? backend?? ; engine; comm = vmm/nccl .. + backend: Optional[str] = "nccl", # reserved this for future use *args, **kwargs, ): - # optionally to save node/edge feature tensors (view) - self.data = {} # place holder for the hetergenous graph + # optionally to save node/edge feature tensors (view) + self.data = {} # place holder for the hetergenous graph self.device = device if partition_book is not None: - raise NotImplementedError("Uneven partition of 1-D disttensor is not turned on yet.") + raise NotImplementedError( + "Uneven partition of 1-D disttensor is not turned on yet.") if pinned_shared: dist_shmem.init_process_group_per_node() @@ -38,15 +41,20 @@ def __init__( col_ptrs = None row_indx = None if dist_shmem.get_local_rank() == 0: - if isinstance(col_ptrs_src, torch.Tensor) and isinstance(row_indx_src, torch.Tensor): + if isinstance(col_ptrs_src, torch.Tensor) and isinstance( + row_indx_src, torch.Tensor): col_ptrs = col_ptrs_src row_indx = row_indx_src - elif col_ptrs_src.endswith('.pt') and row_indx_src.endswith('.pt'): + elif col_ptrs_src.endswith('.pt') and row_indx_src.endswith( + '.pt'): col_ptrs = torch.load(col_ptrs_src, mmap=True) row_indx = torch.load(row_indx_src, mmap=True) - elif col_ptrs_src.endswith('.npy') and row_indx_src.endswith('.npy'): - col_ptrs = torch.from_numpy(np.load(col_ptrs_src, mmap_mode='c')) - row_indx = torch.from_numpy(np.load(row_indx_src, mmap_mode='c')) + elif col_ptrs_src.endswith('.npy') and row_indx_src.endswith( + '.npy'): + col_ptrs = torch.from_numpy( + np.load(col_ptrs_src, mmap_mode='c')) + row_indx = torch.from_numpy( + np.load(row_indx_src, mmap_mode='c')) else: raise ValueError("Unsupported file format.") @@ -54,8 +62,10 @@ def __init__( self.row_indx = dist_shmem.to_shmem(row_indx) else: # 2-gather approach here only - self.col_ptrs = dist_tensor.DistTensor(col_ptrs_src, device = device, backend = backend) - self.row_indx = dist_tensor.DistTensor(row_indx_src, device = device, backend = backend) + self.col_ptrs = dist_tensor.DistTensor(col_ptrs_src, device=device, + backend=backend) + self.row_indx = dist_tensor.DistTensor(row_indx_src, device=device, + backend=backend) @property def num_nodes(self): @@ -65,7 +75,7 @@ def num_nodes(self): def num_edges(self): return self.row_indx.shape[0] - def __getitem__ (self, name: str) -> Any: + def __getitem__(self, name: str) -> Any: return self.data[name] def __setitem__(self, name: str, value: Any) -> None: @@ -78,7 +88,6 @@ def transform_nodes(self, nodes): Args: nodes (_type_): _description_ """ - pass def transform_edges(self, edges): """Transform all seed edges from every rank to the local seed edges @@ -86,7 +95,6 @@ def transform_edges(self, edges): Args: edges (_type_): _description_ """ - pass - def transform_graph() : #back to graph + def transform_graph(): #back to graph pass diff --git a/examples/distributed/wholegraph/nv_distributed_graph/dist_shmem.py b/examples/distributed/wholegraph/nv_distributed_graph/dist_shmem.py index 49dc6ac56a64..fd1e5ec23f76 100644 --- a/examples/distributed/wholegraph/nv_distributed_graph/dist_shmem.py +++ b/examples/distributed/wholegraph/nv_distributed_graph/dist_shmem.py @@ -1,5 +1,6 @@ -"""Utilities for launching distributed GNN tasks. """ +"""Utilities for launching distributed GNN tasks.""" import os + import torch.distributed as dist import torch.multiprocessing as mp @@ -8,13 +9,15 @@ _LOCAL_ROOT_AUTH_KEY = None nprocs_per_node = 1 + def init_process_group_per_node(): """Initialize the distributed process group for each node.""" if _LOCAL_PROCESS_GROUP is None: create_process_group_per_node() return else: - assert dist.get_process_group_ranks(_LOCAL_PROCESS_GROUP)[0] == _LOCAL_ROOT_GLOBAL_RANK + assert dist.get_process_group_ranks( + _LOCAL_PROCESS_GROUP)[0] == _LOCAL_ROOT_GLOBAL_RANK return @@ -23,9 +26,12 @@ def create_process_group_per_node(): global _LOCAL_PROCESS_GROUP, _LOCAL_ROOT_GLOBAL_RANK global nprocs_per_node assert _LOCAL_PROCESS_GROUP is None and _LOCAL_ROOT_GLOBAL_RANK is None - assert dist.is_initialized(), "torch.distributed is not initialized. Please call torch.distributed.init_process_group() first." + assert dist.is_initialized( + ), "torch.distributed is not initialized. Please call torch.distributed.init_process_group() first." - nprocs_per_node = int(os.environ.get('LOCAL_WORLD_SIZE', os.environ.get('SLURM_NTASKS_PER_NODE'))) + nprocs_per_node = int( + os.environ.get('LOCAL_WORLD_SIZE', + os.environ.get('SLURM_NTASKS_PER_NODE'))) world_size = dist.get_world_size() rank = dist.get_rank() if dist.is_initialized() else 0 assert world_size % nprocs_per_node == 0 @@ -33,12 +39,14 @@ def create_process_group_per_node(): num_nodes = world_size // nprocs_per_node node_id = rank // nprocs_per_node for i in range(num_nodes): - node_ranks = list(range(i * nprocs_per_node, (i + 1) * nprocs_per_node)) + node_ranks = list(range(i * nprocs_per_node, + (i + 1) * nprocs_per_node)) pg = dist.new_group(node_ranks) if i == node_id: _LOCAL_PROCESS_GROUP = pg assert _LOCAL_PROCESS_GROUP is not None - _LOCAL_ROOT_GLOBAL_RANK = dist.get_process_group_ranks(_LOCAL_PROCESS_GROUP)[0] + _LOCAL_ROOT_GLOBAL_RANK = dist.get_process_group_ranks( + _LOCAL_PROCESS_GROUP)[0] def get_local_process_group(): @@ -104,7 +112,7 @@ def to_shmem(dataset): The objects can be DGLGraph and Pytorch Tensor, or any customized objects with the same mechanism of using shared memory during pickling process. - Returns + Returns: ------- dataset : Reconstructed dataset in shared memory Returned dataset preserves the same object hierarchy of the input. @@ -127,5 +135,7 @@ def to_shmem(dataset): if dist.get_rank() != local_root: # only non-root process performs pickle.loads() dataset = mp.reductions.ForkingPickler.loads(handle) - dist.barrier(group=local_group) # necessary to prevent unexpected close of any procs beyond this function + dist.barrier( + group=local_group + ) # necessary to prevent unexpected close of any procs beyond this function return dataset diff --git a/examples/distributed/wholegraph/nv_distributed_graph/dist_tensor.py b/examples/distributed/wholegraph/nv_distributed_graph/dist_tensor.py index dc67efb2f8c3..ba329f7c7613 100644 --- a/examples/distributed/wholegraph/nv_distributed_graph/dist_tensor.py +++ b/examples/distributed/wholegraph/nv_distributed_graph/dist_tensor.py @@ -2,12 +2,17 @@ from typing import List, Literal, Optional, Union import numpy as np +import pylibwholegraph import torch -import torch.distributed as dist -import pylibwholegraph -from .wholegraph import create_wg_dist_tensor, create_wg_dist_tensor_from_files, finalize_wholegraph, _wm_global -from .wholegraph import copy_host_global_tensor_to_local +from .wholegraph import ( + _wm_global, + copy_host_global_tensor_to_local, + create_wg_dist_tensor, + create_wg_dist_tensor_from_files, + finalize_wholegraph, +) + class DistTensor: _instance_count = 0 @@ -39,7 +44,9 @@ def __init__( shape: Optional[Union[list, tuple]] = None, dtype: Optional[torch.dtype] = None, device: Optional[Literal["cpu", "cuda"]] = "cpu", - partition_book: Optional[Union[List[int], None]] = None, # location memtype ?? backend?? ; engine; comm = vmm/nccl .. + partition_book: Optional[Union[ + List[int], + None]] = None, # location memtype ?? backend?? ; engine; comm = vmm/nccl .. backend: Optional[str] = "nccl", *args, **kwargs, @@ -48,14 +55,17 @@ def __init__( # Register the cleanup function for safty exit atexit.register(finalize_wholegraph) - self._tensor = None # WholeMemory tensor for now. In future, we may support other types of distributed tensors. + self._tensor = None # WholeMemory tensor for now. In future, we may support other types of distributed tensors. self._device = device if src is None: # Create an empty WholeGraph tensor assert shape is not None, "Please specify the shape of the tensor." assert dtype is not None, "Please specify the dtype of the tensor." - assert len(shape) == 1 or len(shape) == 2, "The shape of the tensor must be 1D or 2D." - self._tensor = create_wg_dist_tensor(list(shape), dtype, device, partition_book, backend, *args, **kwargs) + assert len(shape) == 1 or len( + shape) == 2, "The shape of the tensor must be 1D or 2D." + self._tensor = create_wg_dist_tensor(list(shape), dtype, device, + partition_book, backend, + *args, **kwargs) self._dtype = dtype else: if isinstance(src, list): @@ -63,33 +73,44 @@ def __init__( # Only support the binary file format directly loaded via WM API for now # TODO (@liuc): support merging multiple pt or npy files to create a tensor assert shape is not None and dtype is not None, "For now, read from multiple files are only supported in binary format." - self._tensor = create_wg_dist_tensor_from_files(src, shape, dtype, device, partition_book, backend, *args, **kwargs) + self._tensor = create_wg_dist_tensor_from_files( + src, shape, dtype, device, partition_book, backend, *args, + **kwargs) #self._tensor.from_filelist(src) self._dtype = dtype else: if isinstance(src, torch.Tensor): - self._tensor = create_wg_dist_tensor(list(src.shape), src.dtype, device, partition_book, backend, *args, **kwargs) + self._tensor = create_wg_dist_tensor( + list(src.shape), src.dtype, device, partition_book, + backend, *args, **kwargs) self._dtype = src.dtype host_tensor = src elif isinstance(src, str) and src.endswith('.pt'): host_tensor = torch.load(src, mmap=True) - self._tensor = create_wg_dist_tensor(list(host_tensor.shape), host_tensor.dtype, device, partition_book, backend, *args, **kwargs) + self._tensor = create_wg_dist_tensor( + list(host_tensor.shape), host_tensor.dtype, device, + partition_book, backend, *args, **kwargs) self._dtype = host_tensor.dtype elif isinstance(src, str) and src.endswith('.npy'): host_tensor = torch.from_numpy(np.load(src, mmap_mode='c')) self._dtype = host_tensor.dtype - self._tensor = create_wg_dist_tensor(list(host_tensor.shape), host_tensor.dtype, device, partition_book, backend, *args, **kwargs) + self._tensor = create_wg_dist_tensor( + list(host_tensor.shape), host_tensor.dtype, device, + partition_book, backend, *args, **kwargs) else: - raise ValueError("Unsupported source type. Please provide a torch.Tensor, a file path, or a list of file paths.") + raise ValueError( + "Unsupported source type. Please provide a torch.Tensor, a file path, or a list of file paths." + ) self.load_from_global_tensor(host_tensor) - DistTensor._instance_count += 1 # increase the instance count to track for resource cleanup + DistTensor._instance_count += 1 # increase the instance count to track for resource cleanup def load_from_global_tensor(self, tensor): # input pytorch host tensor (mmapped or in shared host memory), and copy to wholegraph tensor assert self._tensor is not None, "Please create WholeGraph tensor first." self._dtype = tensor.dtype - if isinstance(self._tensor, pylibwholegraph.torch.WholeMemoryEmbedding): + if isinstance(self._tensor, + pylibwholegraph.torch.WholeMemoryEmbedding): _tensor = self._tensor.get_embedding_tensor() else: _tensor = self._tensor @@ -100,16 +121,19 @@ def load_from_local_tensor(self, tensor): assert self._tensor is not None, "Please create WholeGraph tensor first." assert self._tensor.local_shape == tensor.shape, "The shape of the tensor does not match the shape of the local tensor." assert self._dtype == tensor.dtype, "The dtype of the tensor does not match the dtype of the local tensor." - if isinstance(self._tensor, pylibwholegraph.torch.WholeMemoryEmbedding): - self._tensor.get_embedding_tensor().get_local_tensor().copy_(tensor) + if isinstance(self._tensor, + pylibwholegraph.torch.WholeMemoryEmbedding): + self._tensor.get_embedding_tensor().get_local_tensor().copy_( + tensor) else: self._tensor.get_local_tensor().copy_(tensor) - @classmethod - def from_tensor(cls, tensor: torch.Tensor, device: Optional[Literal["cpu", "cuda"]] = "cpu", partition_book: Union[List[int], None] = None, backend: Optional[str] = 'nccl'): - """ - Create a WholeGraph-backed Distributed Tensor from a PyTorch tensor. + def from_tensor(cls, tensor: torch.Tensor, + device: Optional[Literal["cpu", "cuda"]] = "cpu", + partition_book: Union[List[int], None] = None, + backend: Optional[str] = 'nccl'): + """Create a WholeGraph-backed Distributed Tensor from a PyTorch tensor. Parameters ---------- @@ -120,17 +144,20 @@ def from_tensor(cls, tensor: torch.Tensor, device: Optional[Literal["cpu", "cuda backend : str, optional The backend used for communication. Default is "nccl". - Returns + Returns: ------- DistTensor The WholeGraph-backed Distributed Tensor. """ - return cls(src=tensor, device=device, partition_book=partition_book, backend=backend) + return cls(src=tensor, device=device, partition_book=partition_book, + backend=backend) @classmethod - def from_file(cls, file_path: str, device: Optional[Literal["cpu", "cuda"]] = "cpu", partition_book: Union[List[int], None] = None, backend: Optional[str] = 'nccl'): - """ - Create a WholeGraph-backed Distributed Tensor from a file. + def from_file(cls, file_path: str, + device: Optional[Literal["cpu", "cuda"]] = "cpu", + partition_book: Union[List[int], None] = None, + backend: Optional[str] = 'nccl'): + """Create a WholeGraph-backed Distributed Tensor from a file. Parameters ---------- @@ -141,17 +168,16 @@ def from_file(cls, file_path: str, device: Optional[Literal["cpu", "cuda"]] = "c backend : str, optional The backend used for communication. Default is "nccl". - Returns + Returns: ------- DistTensor The WholeGraph-backed Distributed Tensor. """ - return cls(src=file_path, device=device, partition_book=partition_book, backend=backend) - + return cls(src=file_path, device=device, partition_book=partition_book, + backend=backend) def __setitem__(self, idx: torch.Tensor, val: torch.Tensor): - """ - Set the embeddings for the specified node indices. + """Set the embeddings for the specified node indices. This call must be called by all processes. Parameters @@ -170,41 +196,41 @@ def __setitem__(self, idx: torch.Tensor, val: torch.Tensor): self._tensor.scatter(val, idx) def __getitem__(self, idx: torch.Tensor) -> torch.Tensor: - """ - Get the embeddings for the specified node indices (remotely). + """Get the embeddings for the specified node indices (remotely). This call must be called by all processes. Parameters ---------- idx : torch.Tensor Index of the embeddings to collect. - Returns + + Returns: ------- torch.Tensor The requested node embeddings. """ assert self._tensor is not None, "Please create WholeGraph tensor first." idx = idx.cuda() - output_tensor = self._tensor.gather(idx) # output_tensor is on cuda by default + output_tensor = self._tensor.gather( + idx) # output_tensor is on cuda by default return output_tensor def get_local_tensor(self, host_view=False): - """ - Get the local embedding tensor and its element offset at current rank. + """Get the local embedding tensor and its element offset at current rank. - Returns + Returns: ------- (torch.Tensor, int) Tuple of local torch Tensor (converted from DLPack) and its offset. """ - local_tensor, offset = self._tensor.get_local_tensor(host_view = host_view) + local_tensor, offset = self._tensor.get_local_tensor( + host_view=host_view) return local_tensor def get_local_offset(self): - """ - Get the local embedding tensor and its element offset at current rank. + """Get the local embedding tensor and its element offset at current rank. - Returns + Returns: ------- (torch.Tensor, int) Tuple of local torch Tensor (converted from DLPack) and its offset. @@ -213,10 +239,9 @@ def get_local_offset(self): return offset def get_comm(self): - """ - Get the communicator of the WholeGraph embedding. + """Get the communicator of the WholeGraph embedding. - Returns + Returns: ------- WholeMemoryCommunicator The WholeGraph global communicator of the WholeGraph embedding. @@ -255,9 +280,9 @@ def __del__(self): if DistTensor._instance_count == 0: finalize_wholegraph() + class DistEmbedding(DistTensor): - """ - WholeGraph-backed Distributed Embedding Interface for PyTorch. + """WholeGraph-backed Distributed Embedding Interface for PyTorch. Parameters ---------- src: Optional[Union[torch.Tensor, str, List[str]]] @@ -294,30 +319,26 @@ def __init__( device: Optional[Literal["cpu", "cuda"]] = "cpu", partition_book: Union[List[int], None] = None, backend: Optional[str] = "nccl", - cache_policy = None, #Optional[pylibwholegraph.WholeMemoryCachePolicy] = None, + cache_policy=None, #Optional[pylibwholegraph.WholeMemoryCachePolicy] = None, gather_sms: Optional[int] = -1, round_robin_size: int = 0, name: Optional[str] = None, ): self._name = name - super().__init__(src, shape, dtype, device, partition_book, backend, cache_policy=cache_policy, gather_sms=gather_sms, round_robin_size=round_robin_size) - self._embedding = self._tensor # returned _tensor is a WmEmbedding object + super().__init__(src, shape, dtype, device, partition_book, backend, + cache_policy=cache_policy, gather_sms=gather_sms, + round_robin_size=round_robin_size) + self._embedding = self._tensor # returned _tensor is a WmEmbedding object self._tensor = self._embedding.get_embedding_tensor() @classmethod - def from_tensor( - cls, - tensor: torch.Tensor, - device: Literal["cpu", "cuda"] = "cpu", - partition_book: Union[List[int], None] = None, - name: Optional[str] = None, - cache_policy = None, - *args, - **kwargs - ): - """ - Create a WholeGraph-backed Distributed Embedding (hooked with PyT's grad tracing) from a PyTorch tensor. + def from_tensor(cls, tensor: torch.Tensor, device: Literal["cpu", + "cuda"] = "cpu", + partition_book: Union[List[int], None] = None, + name: Optional[str] = None, cache_policy=None, *args, + **kwargs): + """Create a WholeGraph-backed Distributed Embedding (hooked with PyT's grad tracing) from a PyTorch tensor. Parameters ---------- @@ -328,26 +349,20 @@ def from_tensor( name : str, optional The name of the tensor. - Returns + Returns: ------- DistEmbedding The WholeGraph-backed Distributed Tensor. """ - return cls(tensor, device, partition_book, name, cache_policy, *args, **kwargs) + return cls(tensor, device, partition_book, name, cache_policy, *args, + **kwargs) @classmethod - def from_file( - cls, - file_path: str, - device: Literal["cpu", "cuda"] = "cpu", - partition_book: Union[List[int], None] = None, - name: Optional[str] = None, - cache_policy = None, - *args, - **kwargs - ): - """ - Create a WholeGraph-backed Distributed Tensor from a file. + def from_file(cls, file_path: str, device: Literal["cpu", "cuda"] = "cpu", + partition_book: Union[List[int], None] = None, + name: Optional[str] = None, cache_policy=None, *args, + **kwargs): + """Create a WholeGraph-backed Distributed Tensor from a file. Parameters ---------- @@ -358,17 +373,16 @@ def from_file( name : str, optional The name of the tensor. - Returns + Returns: ------- DistTensor The WholeGraph-backed Distributed Tensor. """ - return cls(file_path, device, partition_book, name, cache_policy, *args, **kwargs) - + return cls(file_path, device, partition_book, name, cache_policy, + *args, **kwargs) def __setitem__(self, idx: torch.Tensor, val: torch.Tensor): - """ - Set the embeddings for the specified node indices. + """Set the embeddings for the specified node indices. This call must be called by all processes. Parameters @@ -387,22 +401,23 @@ def __setitem__(self, idx: torch.Tensor, val: torch.Tensor): self._embedding.get_embedding_tensor().scatter(val, idx) def __getitem__(self, idx: torch.Tensor) -> torch.Tensor: - """ - Get the embeddings for the specified node indices (remotely). + """Get the embeddings for the specified node indices (remotely). This call must be called by all processes. Parameters ---------- idx : torch.Tensor Index of the embeddings to collect. - Returns + + Returns: ------- torch.Tensor The requested node embeddings. """ assert self._tensor is not None, "Please create WholeGraph tensor first." idx = idx.cuda() - output_tensor = self._embedding.gather(idx) # output_tensor is on cuda by default + output_tensor = self._embedding.gather( + idx) # output_tensor is on cuda by default return output_tensor @property @@ -421,4 +436,4 @@ def __repr__(self): return tensor_repr def __del__(self): - super().__del__() \ No newline at end of file + super().__del__() diff --git a/examples/distributed/wholegraph/nv_distributed_graph/wholegraph.py b/examples/distributed/wholegraph/nv_distributed_graph/wholegraph.py index 6e44538dc150..f792c84dcee3 100644 --- a/examples/distributed/wholegraph/nv_distributed_graph/wholegraph.py +++ b/examples/distributed/wholegraph/nv_distributed_graph/wholegraph.py @@ -1,13 +1,14 @@ -from typing import Tuple, List, Union +from typing import List, Tuple, Union +import pylibwholegraph.torch as wgth import torch import torch.distributed as dist -import pylibwholegraph.torch as wgth from . import dist_shmem _wm_global = False + def init_wholegraph(): global _wm_global @@ -17,13 +18,17 @@ def init_wholegraph(): local_size = dist_shmem.get_local_size() local_rank = dist_shmem.get_local_rank() - wgth.init(dist.get_rank(), dist.get_world_size(), local_rank, local_size=local_size, wm_log_level="info") - print(f"[Rank {dist.get_rank()}] WholeGraph Initialization: " - f"{dist.get_world_size()} GPUs are used with {local_size} GPUs per node.") + wgth.init(dist.get_rank(), dist.get_world_size(), local_rank, + local_size=local_size, wm_log_level="info") + print( + f"[Rank {dist.get_rank()}] WholeGraph Initialization: " + f"{dist.get_world_size()} GPUs are used with {local_size} GPUs per node." + ) global_comm = wgth.comm.get_global_communicator("nccl") _wm_global = True return global_comm + def finalize_wholegraph(): global _wm_global if _wm_global is False: @@ -31,9 +36,9 @@ def finalize_wholegraph(): wgth.finalize() _wm_global = False + def nvlink_network(): - r""" - Check if the current hardware supports cross-node NVLink network. + r"""Check if the current hardware supports cross-node NVLink network. """ if not _wm_global: raise RuntimeError("WholeGraph is not initialized.") @@ -56,20 +61,19 @@ def nvlink_network(): return False -def copy_host_global_tensor_to_local( - wm_tensor, host_tensor, wm_comm -): - local_tensor, local_start = wm_tensor.get_local_tensor( - host_view=False - ) + +def copy_host_global_tensor_to_local(wm_tensor, host_tensor, wm_comm): + local_tensor, local_start = wm_tensor.get_local_tensor(host_view=False) ## enable these checks when the wholegraph is updated to 24.10 #local_ref_start = wm_tensor.get_local_entry_start() #local_ref_count = wm_tensor.get_local_entry_count() #assert local_start == local_ref_start #assert local_tensor.shape[0] == local_ref_count - local_tensor.copy_(host_tensor[local_start : local_start + local_tensor.shape[0]]) + local_tensor.copy_(host_tensor[local_start:local_start + + local_tensor.shape[0]]) wm_comm.barrier() + def create_pyg_subgraph(WG_SampleOutput) -> Tuple: # PyG_SampleOutput (node, row, col, edge, batch...): # node (torch.Tensor): The sampled nodes in the original graph. @@ -85,8 +89,9 @@ def create_pyg_subgraph(WG_SampleOutput) -> Tuple: num_sampled_nodes = [] node = sampled_nodes_list[0] - for hop in range(len(sampled_nodes_list)-1): - sampled_nodes = len(sampled_nodes_list[hop]) - len(sampled_nodes_list[hop+1]) + for hop in range(len(sampled_nodes_list) - 1): + sampled_nodes = len(sampled_nodes_list[hop]) - len( + sampled_nodes_list[hop + 1]) num_sampled_nodes.append(sampled_nodes) num_sampled_nodes.append(len(sampled_nodes_list[-1])) num_sampled_nodes.reverse() @@ -95,32 +100,36 @@ def create_pyg_subgraph(WG_SampleOutput) -> Tuple: num_sampled_edges = [len(csr_col_ind_list[-1])] # Loop in reverse order, starting from the second last layer for layer in range(layers - 2, -1, -1): - num_sampled_edges.append(len(csr_col_ind_list[layer] - len(csr_col_ind_list[layer + 1]))) + num_sampled_edges.append( + len(csr_col_ind_list[layer] - len(csr_col_ind_list[layer + 1]))) row = csr_col_ind_list[0] # rows - col = edge_indice_list[0][1] # dst node + col = edge_indice_list[0][1] # dst node edge = None batch = None return node, row, col, edge, batch, num_sampled_nodes, num_sampled_edges + def sample_nodes_wmb_fn(wg_sampler, seeds, fanouts): # WG_SampleOutput (target_gids, edge_indice, csr_row_ptr, csr_col_ind): # target_gids [1D tensors]: unique sampled global node ids for each layer # edge_indice [rank-2 tensors]: edge list [src, des] for each layer, in local node id (start from 0 for each layer) to confirm # csr_row_ptr [1D tensors]: csr row ptrs for each subgraph in each layer (starting from 0 for each subgraph) # csr_col_ind [1D tensors]: csr col indx for each subgraph in each layer - WG_SampleOutput = wg_sampler.multilayer_sample_without_replacement(seeds, fanouts, None) + WG_SampleOutput = wg_sampler.multilayer_sample_without_replacement( + seeds, fanouts, None) return create_pyg_subgraph(WG_SampleOutput) + def create_wg_dist_tensor( - shape: list, - dtype: torch.dtype, - location: str = "cpu", - partition_book: Union[List[int], None] = None, # default is even partition - backend: str = "nccl", # default is nccl; support nccl, vmm, nvshmem... - **kwargs -): + shape: list, + dtype: torch.dtype, + location: str = "cpu", + partition_book: Union[List[int], + None] = None, # default is even partition + backend: str = "nccl", # default is nccl; support nccl, vmm, nvshmem... + **kwargs): """Create a WholeGraph-managed distributed tensor. Parameters @@ -161,13 +170,14 @@ def create_wg_dist_tensor( embedding_wholememory_location, dtype, shape, - cache_policy=cache_policy, # disable cache for now + cache_policy=cache_policy, # disable cache for now #embedding_entry_partition=partition_book, **kwargs #tensor_entry_partition=None # important to do load balance ) else: - assert len(shape) == 2 or len(shape) == 1, "The shape of the tensor must be 2D or 1D." + assert len(shape) == 2 or len( + shape) == 1, "The shape of the tensor must be 2D or 1D." wm_embedding = wgth.create_wholememory_tensor( global_comm, embedding_wholememory_type, @@ -179,15 +189,16 @@ def create_wg_dist_tensor( ) return wm_embedding + def create_wg_dist_tensor_from_files( - file_list: List[str], - shape: list, - dtype: torch.dtype, - location: str = "cpu", - partition_book: Union[List[int], None] = None, # default is even partition - backend: str = "nccl", # default is nccl; support nccl, vmm, nvshmem... - **kwargs -): + file_list: List[str], + shape: list, + dtype: torch.dtype, + location: str = "cpu", + partition_book: Union[List[int], + None] = None, # default is even partition + backend: str = "nccl", # default is nccl; support nccl, vmm, nvshmem... + **kwargs): """Create a WholeGraph-managed distributed tensor from a list of files. Parameters @@ -231,12 +242,12 @@ def create_wg_dist_tensor_from_files( file_list, dtype, shape[1], - cache_policy=cache_policy, # disable cache for now + cache_policy=cache_policy, # disable cache for now #embedding_entry_partition=partition_book, - **kwargs - ) + **kwargs) else: - assert len(shape) == 2 or len(shape) == 1, "The shape of the tensor must be 2D or 1D." + assert len(shape) == 2 or len( + shape) == 1, "The shape of the tensor must be 2D or 1D." last_dim_size = 0 if len(shape) == 1 else shape[1] wm_embedding = wgth.create_wholememory_tensor_from_filelist( global_comm, @@ -247,4 +258,4 @@ def create_wg_dist_tensor_from_files( last_dim_size, #tensor_entry_partition=partition_book # important to do load balance ) - return wm_embedding \ No newline at end of file + return wm_embedding diff --git a/examples/distributed/wholegraph/papers100m_dist_wholegraph_nc.py b/examples/distributed/wholegraph/papers100m_dist_wholegraph_nc.py index 9de67785c289..a09ace9c4d1d 100644 --- a/examples/distributed/wholegraph/papers100m_dist_wholegraph_nc.py +++ b/examples/distributed/wholegraph/papers100m_dist_wholegraph_nc.py @@ -13,26 +13,24 @@ import torch import torch.distributed as dist import torch.nn.functional as F +from feature_store import WholeGraphFeatureStore +from graph_store import WholeGraphGraphStore +from nv_distributed_graph import dist_shmem from ogb.nodeproppred import PygNodePropPredDataset from torch.nn.parallel import DistributedDataParallel from torchmetrics import Accuracy +import torch_geometric.transforms as T from torch_geometric.loader import NodeLoader from torch_geometric.nn import GCN - -import torch_geometric.transforms as T from torch_geometric.sampler import BaseSampler -from nv_distributed_graph import dist_shmem -from feature_store import WholeGraphFeatureStore -from graph_store import WholeGraphGraphStore - class WholeGraphSampler(BaseSampler): - r""" - A naive sampler class for WholeGraph graph storage that only supports uniform node-based sampling on homogeneous graph. + r"""A naive sampler class for WholeGraph graph storage that only supports uniform node-based sampling on homogeneous graph. """ - from torch_geometric.sampler import SamplerOutput, NodeSamplerInput + from torch_geometric.sampler import NodeSamplerInput, SamplerOutput + def __init__( self, graph: WholeGraphGraphStore, @@ -45,19 +43,18 @@ def __init__( row_indx, col_ptrs, _ = graph.csc() self.wg_sampler.set_csr_graph(col_ptrs._tensor, row_indx._tensor) - def sample_from_nodes( - self, - inputs: NodeSamplerInput - ) -> SamplerOutput: - r""" - Sample subgraphs from the given nodes based on uniform node-based sampling. + def sample_from_nodes(self, inputs: NodeSamplerInput) -> SamplerOutput: + r"""Sample subgraphs from the given nodes based on uniform node-based sampling. """ - seed = inputs.node.cuda(non_blocking=True) # WholeGraph Sampler needs all seeds on device - WG_SampleOutput = self.wg_sampler.multilayer_sample_without_replacement(seed, self.num_neighbors, None) + seed = inputs.node.cuda( + non_blocking=True) # WholeGraph Sampler needs all seeds on device + WG_SampleOutput = self.wg_sampler.multilayer_sample_without_replacement( + seed, self.num_neighbors, None) out = WholeGraphGraphStore.create_pyg_subgraph(WG_SampleOutput) out.metadata = (inputs.input_id, inputs.time) return out + def run(world_size, rank, local_rank, device): wall_clock_start = time.perf_counter() @@ -71,10 +68,13 @@ def run(world_size, rank, local_rank, device): # Load the dataset in the local root process and share it with others if dist_shmem.get_local_rank() == 0: # Use pre_transform to avoid on-fly graph format conversion and reduce RAM usage during runtime, especially for large graphs. - dataset = PygNodePropPredDataset(name='ogbn-papers100M', root='/workspace', pre_transform=transform) + dataset = PygNodePropPredDataset(name='ogbn-papers100M', + root='/workspace', + pre_transform=transform) else: dataset = None - dataset = dist_shmem.to_shmem(dataset) # move dataset to shmem for local ranks access + dataset = dist_shmem.to_shmem( + dataset) # move dataset to shmem for local ranks access split_idx = dataset.get_idx_split() split_idx['train'] = split_idx['train'].split( @@ -93,8 +93,8 @@ def run(world_size, rank, local_rank, device): kwargs = dict( data=(feature_store, graph_store), batch_size=1024, - num_workers=0, # with wholegraph graph store you don't need workers - filter_per_worker=False, # WholeGraph feature fetching is not fork-safe + num_workers=0, # with wholegraph graph store you don't need workers + filter_per_worker=False, # WholeGraph feature fetching is not fork-safe ) node_sampler = WholeGraphSampler( @@ -109,15 +109,17 @@ def run(world_size, rank, local_rank, device): drop_last=True, **kwargs, ) - val_loader = NodeLoader(input_nodes=split_idx['valid'], node_sampler=node_sampler, **kwargs) - test_loader = NodeLoader(input_nodes=split_idx['test'], node_sampler=node_sampler, **kwargs) + val_loader = NodeLoader(input_nodes=split_idx['valid'], + node_sampler=node_sampler, **kwargs) + test_loader = NodeLoader(input_nodes=split_idx['test'], + node_sampler=node_sampler, **kwargs) eval_steps = 1000 model = GCN(num_features, 256, 2, num_classes) acc = Accuracy(task="multiclass", num_classes=num_classes).to(device) model = DistributedDataParallel(model.to(device), device_ids=[local_rank]) optimizer = torch.optim.Adam(model.parameters(), lr=0.001, - weight_decay=5e-4) + weight_decay=5e-4) if rank == 0: prep_time = round(time.perf_counter() - wall_clock_start, 2) @@ -164,11 +166,9 @@ def test(loader: NodeLoader, num_steps: Optional[int] = None): eval_acc = test(val_loader, num_steps=eval_steps) if rank == 0: print(f"Val Accuracy: {eval_acc:.4f}%", ) - print( - f"Epoch {epoch:05d} | " - f"Accuracy {eval_acc:.4f} | " - f"Time {epoch_end - start:.2f}" - ) + print(f"Epoch {epoch:05d} | " + f"Accuracy {eval_acc:.4f} | " + f"Time {epoch_end - start:.2f}") acc.reset() dist.barrier() @@ -178,6 +178,7 @@ def test(loader: NodeLoader, num_steps: Optional[int] = None): print(f"Test Accuracy: {test_acc:.4f}%", ) dist.destroy_process_group() if dist.is_initialized() else None + if __name__ == '__main__': # Get the world size from the WORLD_SIZE variable or directly from SLURM: world_size = int( @@ -191,4 +192,3 @@ def test(loader: NodeLoader, num_steps: Optional[int] = None): device = torch.device(local_rank) torch.cuda.set_device(device) run(world_size, rank, local_rank, device) - From 1e2bd6ff0676f8dc7aca0ab8ff39f7ff400c4e2e Mon Sep 17 00:00:00 2001 From: chang-l Date: Thu, 17 Oct 2024 17:19:04 -0700 Subject: [PATCH 3/6] Minor fix for typos and comments --- .../distributed/wholegraph/feature_store.py | 17 +------------ .../distributed/wholegraph/graph_store.py | 25 ++++++++----------- 2 files changed, 11 insertions(+), 31 deletions(-) diff --git a/examples/distributed/wholegraph/feature_store.py b/examples/distributed/wholegraph/feature_store.py index c2d7ee9efd44..895adfe48258 100644 --- a/examples/distributed/wholegraph/feature_store.py +++ b/examples/distributed/wholegraph/feature_store.py @@ -42,7 +42,7 @@ def __init__(self, pyg_data): self._store = { } # A dictionary of tuple to hold the feature embeddings - if dist_shmem.get_local_rank() == dist.get_rank(): + if dist_shmem.get_local_size() == dist.get_world_size(): self.backend = 'vmm' else: self.backend = 'vmm' if nvlink_network() else 'nccl' @@ -63,17 +63,6 @@ def __init__(self, pyg_data): self.put_tensor(pyg_data[group_name][attr_name], group_name=group_name, attr_name=attr_name, index=None) - # This is a hack for MAG240M dataset, to add node features for 'institution' and 'author' nodes. - # This should not be presented in the upstream code. - elif attr_name == 'num_nodes': - feature_dim = 768 - num_nodes = group[attr_name] - shape = [num_nodes, feature_dim] - self[group_name, 'x', - None] = DistEmbedding(shape=shape, - dtype=torch.float16, - device="cpu", - backend=self.backend) else: raise TypeError( "Expected pyg_data to be of type torch_geometric.data.Data or torch_geometric.data.HeteroData." @@ -82,10 +71,6 @@ def __init__(self, pyg_data): def _put_tensor(self, tensor: torch.Tensor, attr): """Creates and stores features (either DistTensor or DistEmbedding) from the given tensor, using a key derived from the group and attribute name. - - Args: - tensor (torch.Tensor): The tensor to be passed to the feature store. - attr: PyG's TensorAttr to fully specify each feature store. """ key = (attr.group_name, attr.attr_name) out = self._store.get(key) diff --git a/examples/distributed/wholegraph/graph_store.py b/examples/distributed/wholegraph/graph_store.py index de2cbe11dba0..c1cdb74394b3 100644 --- a/examples/distributed/wholegraph/graph_store.py +++ b/examples/distributed/wholegraph/graph_store.py @@ -16,7 +16,7 @@ class WholeGraphEdgeAttr(EdgeAttr): def __init__( self, edge_type: Optional[ - EdgeType] = None, # use string to represent edge type for simplicity + EdgeType] = None, is_sorted: bool = False, size: Optional[Tuple[int, int]] = None, ): @@ -41,7 +41,7 @@ def __init__(self, pyg_data, format='wholegraph'): if format == 'wholegraph': pinned_shared = False - if dist_shmem.get_local_rank() == dist.get_rank(): + if dist_shmem.get_local_size() == dist.get_world_size(): backend = 'vmm' else: backend = 'vmm' if nvlink_network() else 'nccl' @@ -57,12 +57,11 @@ def __init__(self, pyg_data, format='wholegraph'): if 'adj_t' not in pyg_data: row, col = None, None if dist_shmem.get_local_rank() == 0: - row, col, _ = pyg_data.csc() # discard permutation for now + row, col, _ = pyg_data.csc() row = dist_shmem.to_shmem(row) col = dist_shmem.to_shmem(col) size = pyg_data.size() else: - # issue: it wont work if adj_t is a SparseTensor col = pyg_data.adj_t.crow_indices() row = pyg_data.adj_t.col_indices() size = pyg_data.adj_t.size()[::-1] @@ -87,12 +86,11 @@ def __init__(self, pyg_data, format='wholegraph'): row, col = None, None if dist_shmem.get_local_rank() == 0: row, col, _ = edge_store.csc( - ) # discard permutation for now + ) row = dist_shmem.to_shmem(row) col = dist_shmem.to_shmem(col) size = edge_store.size() else: - # issue: this will also if adj_t is a SparseTensor col = edge_store.adj_t.crow_indices() row = edge_store.adj_t.col_indices() size = edge_store.adj_t.size()[::-1] @@ -106,21 +104,18 @@ def __init__(self, pyg_data, format='wholegraph'): self.put_adj_t(graph, edge_type=edge_type, size=size) def put_adj_t(self, adj_t: DistGraphCSC, *args, **kwargs) -> bool: - r"""Synchronously adds an :obj:`edge_index` tuple to the - :class:`GraphStore`. + """Add an adj_t (adj with transpose) matrix, :obj:`DistGraphCSC` + to :class:`WholeGraphGraphStore`. Returns whether insertion was successful. - - Args: - edge_index (Tuple[torch.Tensor, torch.Tensor]): The - :obj:`edge_index` tuple in a format specified in - :class:`EdgeAttr`. - *args: Arguments passed to :class:`EdgeAttr`. - **kwargs: Keyword arguments passed to :class:`EdgeAttr`. """ edge_attr = self._edge_attr_cls.cast(*args, **kwargs) return self._put_adj_t(adj_t, edge_attr) def get_adj_t(self, *args, **kwargs) -> DistGraphCSC: + """Retrieves an adj_t (adj with transpose) matrix, :obj:`DistGraphCSC` + from :class:`WholeGraphGraphStore`. + Return: :obj:`DistGraphCSC` + """ edge_attr = self._edge_attr_cls.cast(*args, **kwargs) graph_adj_t = self._get_adj_t(edge_attr) if graph_adj_t is None: From fb432d8fb45d7bfe626897debb50dbea46907cc3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Oct 2024 00:21:50 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/distributed/wholegraph/graph_store.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/distributed/wholegraph/graph_store.py b/examples/distributed/wholegraph/graph_store.py index c1cdb74394b3..6bba370921cb 100644 --- a/examples/distributed/wholegraph/graph_store.py +++ b/examples/distributed/wholegraph/graph_store.py @@ -15,8 +15,7 @@ class WholeGraphEdgeAttr(EdgeAttr): r"""Edge attribute class for WholeGraph GraphStore enforcing layout to be CSC.""" def __init__( self, - edge_type: Optional[ - EdgeType] = None, + edge_type: Optional[EdgeType] = None, is_sorted: bool = False, size: Optional[Tuple[int, int]] = None, ): @@ -85,8 +84,7 @@ def __init__(self, pyg_data, format='wholegraph'): if 'adj_t' not in edge_store: row, col = None, None if dist_shmem.get_local_rank() == 0: - row, col, _ = edge_store.csc( - ) + row, col, _ = edge_store.csc() row = dist_shmem.to_shmem(row) col = dist_shmem.to_shmem(col) size = edge_store.size() From 12b604b629c28e7d7221799080678a6927b9dbda Mon Sep 17 00:00:00 2001 From: chang-l Date: Fri, 1 Nov 2024 16:46:50 -0700 Subject: [PATCH 5/6] Example reorg under NVIDIA RAPIDS folder --- examples/distributed/NVIDIA-RAPIDS/README.md | 8 ++++++++ .../cugraph}/papers100m_gcn_cugraph_multinode.py | 0 .../distributed/{ => NVIDIA-RAPIDS}/wholegraph/README | 0 .../{ => NVIDIA-RAPIDS}/wholegraph/benchmark_data.py | 0 .../{ => NVIDIA-RAPIDS}/wholegraph/feature_store.py | 0 .../{ => NVIDIA-RAPIDS}/wholegraph/graph_store.py | 0 .../wholegraph/nv_distributed_graph/__init__.py | 0 .../wholegraph/nv_distributed_graph/dist_graph.py | 0 .../wholegraph/nv_distributed_graph/dist_shmem.py | 0 .../wholegraph/nv_distributed_graph/dist_tensor.py | 0 .../wholegraph/nv_distributed_graph/wholegraph.py | 0 .../wholegraph/papers100m_dist_wholegraph_nc.py | 0 12 files changed, 8 insertions(+) create mode 100644 examples/distributed/NVIDIA-RAPIDS/README.md rename examples/{multi_gpu => distributed/NVIDIA-RAPIDS/cugraph}/papers100m_gcn_cugraph_multinode.py (100%) rename examples/distributed/{ => NVIDIA-RAPIDS}/wholegraph/README (100%) rename examples/distributed/{ => NVIDIA-RAPIDS}/wholegraph/benchmark_data.py (100%) rename examples/distributed/{ => NVIDIA-RAPIDS}/wholegraph/feature_store.py (100%) rename examples/distributed/{ => NVIDIA-RAPIDS}/wholegraph/graph_store.py (100%) rename examples/distributed/{ => NVIDIA-RAPIDS}/wholegraph/nv_distributed_graph/__init__.py (100%) rename examples/distributed/{ => NVIDIA-RAPIDS}/wholegraph/nv_distributed_graph/dist_graph.py (100%) rename examples/distributed/{ => NVIDIA-RAPIDS}/wholegraph/nv_distributed_graph/dist_shmem.py (100%) rename examples/distributed/{ => NVIDIA-RAPIDS}/wholegraph/nv_distributed_graph/dist_tensor.py (100%) rename examples/distributed/{ => NVIDIA-RAPIDS}/wholegraph/nv_distributed_graph/wholegraph.py (100%) rename examples/distributed/{ => NVIDIA-RAPIDS}/wholegraph/papers100m_dist_wholegraph_nc.py (100%) diff --git a/examples/distributed/NVIDIA-RAPIDS/README.md b/examples/distributed/NVIDIA-RAPIDS/README.md new file mode 100644 index 000000000000..10da0a53aca1 --- /dev/null +++ b/examples/distributed/NVIDIA-RAPIDS/README.md @@ -0,0 +1,8 @@ +# Distributed Training with PyG using NVIDIA RAPIDS libraries + +This directory contains examples for distributed graph learning using NVIDIA RAPIDS cuGraph/WholeGraph libraries. These examples minimize CPU interruptions and maximize GPU throughput advantages during the graph dataloading stage. In our tests, we normally observe at least over a tenfold speedup compared to traditional CPU-based [RPC methods](../pyg). Additionally, the libraries are user-friendly, enabling flexible integration with minimal effort to upgrade from users' GNN training workflows. + +Currently, we offer two integration options for NVIDIA RAPIDS support: the first is through cuGraph, which provides a higher-level API (cuGraph dataloader), and the second is through WholeGraph, leveraging PyG remote backend APIs for better flexibility to accelerate GNN training and various GraphML tasks. We plan to merge these two paths soon under [cugraph-gnn](https://github.com/rapidsai/cugraph-gnn), creating a unified, multi-level APIs to simplify the user learning curve. + +1. [`cuGraph`](./cugraph): Distributed training via NVIDIA RAPIDS [cuGraph](https://github.com/rapidsai/cugraph) library. +2. [`WholeGraph`](./wholegraph): Distributed training via PyG remote backend APIs and NVIDIA RAPIDS [WholeGraph](https://github.com/rapidsai/wholegraph) library. \ No newline at end of file diff --git a/examples/multi_gpu/papers100m_gcn_cugraph_multinode.py b/examples/distributed/NVIDIA-RAPIDS/cugraph/papers100m_gcn_cugraph_multinode.py similarity index 100% rename from examples/multi_gpu/papers100m_gcn_cugraph_multinode.py rename to examples/distributed/NVIDIA-RAPIDS/cugraph/papers100m_gcn_cugraph_multinode.py diff --git a/examples/distributed/wholegraph/README b/examples/distributed/NVIDIA-RAPIDS/wholegraph/README similarity index 100% rename from examples/distributed/wholegraph/README rename to examples/distributed/NVIDIA-RAPIDS/wholegraph/README diff --git a/examples/distributed/wholegraph/benchmark_data.py b/examples/distributed/NVIDIA-RAPIDS/wholegraph/benchmark_data.py similarity index 100% rename from examples/distributed/wholegraph/benchmark_data.py rename to examples/distributed/NVIDIA-RAPIDS/wholegraph/benchmark_data.py diff --git a/examples/distributed/wholegraph/feature_store.py b/examples/distributed/NVIDIA-RAPIDS/wholegraph/feature_store.py similarity index 100% rename from examples/distributed/wholegraph/feature_store.py rename to examples/distributed/NVIDIA-RAPIDS/wholegraph/feature_store.py diff --git a/examples/distributed/wholegraph/graph_store.py b/examples/distributed/NVIDIA-RAPIDS/wholegraph/graph_store.py similarity index 100% rename from examples/distributed/wholegraph/graph_store.py rename to examples/distributed/NVIDIA-RAPIDS/wholegraph/graph_store.py diff --git a/examples/distributed/wholegraph/nv_distributed_graph/__init__.py b/examples/distributed/NVIDIA-RAPIDS/wholegraph/nv_distributed_graph/__init__.py similarity index 100% rename from examples/distributed/wholegraph/nv_distributed_graph/__init__.py rename to examples/distributed/NVIDIA-RAPIDS/wholegraph/nv_distributed_graph/__init__.py diff --git a/examples/distributed/wholegraph/nv_distributed_graph/dist_graph.py b/examples/distributed/NVIDIA-RAPIDS/wholegraph/nv_distributed_graph/dist_graph.py similarity index 100% rename from examples/distributed/wholegraph/nv_distributed_graph/dist_graph.py rename to examples/distributed/NVIDIA-RAPIDS/wholegraph/nv_distributed_graph/dist_graph.py diff --git a/examples/distributed/wholegraph/nv_distributed_graph/dist_shmem.py b/examples/distributed/NVIDIA-RAPIDS/wholegraph/nv_distributed_graph/dist_shmem.py similarity index 100% rename from examples/distributed/wholegraph/nv_distributed_graph/dist_shmem.py rename to examples/distributed/NVIDIA-RAPIDS/wholegraph/nv_distributed_graph/dist_shmem.py diff --git a/examples/distributed/wholegraph/nv_distributed_graph/dist_tensor.py b/examples/distributed/NVIDIA-RAPIDS/wholegraph/nv_distributed_graph/dist_tensor.py similarity index 100% rename from examples/distributed/wholegraph/nv_distributed_graph/dist_tensor.py rename to examples/distributed/NVIDIA-RAPIDS/wholegraph/nv_distributed_graph/dist_tensor.py diff --git a/examples/distributed/wholegraph/nv_distributed_graph/wholegraph.py b/examples/distributed/NVIDIA-RAPIDS/wholegraph/nv_distributed_graph/wholegraph.py similarity index 100% rename from examples/distributed/wholegraph/nv_distributed_graph/wholegraph.py rename to examples/distributed/NVIDIA-RAPIDS/wholegraph/nv_distributed_graph/wholegraph.py diff --git a/examples/distributed/wholegraph/papers100m_dist_wholegraph_nc.py b/examples/distributed/NVIDIA-RAPIDS/wholegraph/papers100m_dist_wholegraph_nc.py similarity index 100% rename from examples/distributed/wholegraph/papers100m_dist_wholegraph_nc.py rename to examples/distributed/NVIDIA-RAPIDS/wholegraph/papers100m_dist_wholegraph_nc.py From 7193592c6ffe242c12cc0237844fb595473fe43c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 23:48:46 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/distributed/NVIDIA-RAPIDS/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/distributed/NVIDIA-RAPIDS/README.md b/examples/distributed/NVIDIA-RAPIDS/README.md index 10da0a53aca1..92682efb2ff0 100644 --- a/examples/distributed/NVIDIA-RAPIDS/README.md +++ b/examples/distributed/NVIDIA-RAPIDS/README.md @@ -5,4 +5,4 @@ This directory contains examples for distributed graph learning using NVIDIA RAP Currently, we offer two integration options for NVIDIA RAPIDS support: the first is through cuGraph, which provides a higher-level API (cuGraph dataloader), and the second is through WholeGraph, leveraging PyG remote backend APIs for better flexibility to accelerate GNN training and various GraphML tasks. We plan to merge these two paths soon under [cugraph-gnn](https://github.com/rapidsai/cugraph-gnn), creating a unified, multi-level APIs to simplify the user learning curve. 1. [`cuGraph`](./cugraph): Distributed training via NVIDIA RAPIDS [cuGraph](https://github.com/rapidsai/cugraph) library. -2. [`WholeGraph`](./wholegraph): Distributed training via PyG remote backend APIs and NVIDIA RAPIDS [WholeGraph](https://github.com/rapidsai/wholegraph) library. \ No newline at end of file +1. [`WholeGraph`](./wholegraph): Distributed training via PyG remote backend APIs and NVIDIA RAPIDS [WholeGraph](https://github.com/rapidsai/wholegraph) library.