[Example] Add WholeGraph to accelerate PyG dataloaders with GPUs #9714
Open: chang-l wants to merge 8 commits into pyg-team:master from chang-l:add-uva-ddp-pyg (base: master)
+1,790
−0
Commits (8)
1240df9  Add example (chang-l)
9f170b0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
1e2bd6f  Minor fix for typos and comments (chang-l)
fb432d8  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
f86e75b  Merge branch 'master' into add-uva-ddp-pyg (puririshi98)
9ebbd19  Merge branch 'master' into add-uva-ddp-pyg (puririshi98)
12b604b  Example reorg under NVIDIA RAPIDS folder (chang-l)
7193592  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
@@ -0,0 +1,8 @@
# Distributed Training with PyG using NVIDIA RAPIDS libraries

This directory contains examples of distributed graph learning using the NVIDIA RAPIDS cuGraph/WholeGraph libraries. These examples minimize CPU interruptions and maximize GPU throughput during the graph dataloading stage. In our tests, we typically observe at least a tenfold speedup over traditional CPU-based [RPC methods](../pyg). The libraries are also user-friendly, allowing users to upgrade their existing GNN training workflows with minimal effort.

Currently, we offer two integration options for NVIDIA RAPIDS support: the first is through cuGraph, which provides a higher-level API (the cuGraph dataloader); the second is through WholeGraph, which leverages the PyG remote backend APIs for greater flexibility in accelerating GNN training and various GraphML tasks. We plan to merge these two paths under [cugraph-gnn](https://github.com/rapidsai/cugraph-gnn) soon, creating unified, multi-level APIs that simplify the learning curve.

1. [`cuGraph`](./cugraph): Distributed training via the NVIDIA RAPIDS [cuGraph](https://github.com/rapidsai/cugraph) library.
1. [`WholeGraph`](./wholegraph): Distributed training via the PyG remote backend APIs and the NVIDIA RAPIDS [WholeGraph](https://github.com/rapidsai/wholegraph) library.
File renamed without changes.
@@ -0,0 +1,62 @@
# Using the NVIDIA WholeGraph Library for Distributed Training with PyG

**[RAPIDS WholeGraph](https://github.com/rapidsai/wholegraph)**
NVIDIA WholeGraph is designed to optimize the training of Graph Neural Networks (GNNs) that are often constrained by data loading. It provides an underlying storage structure, called WholeMemory, which efficiently manages data storage and communication across disk, RAM, and device memory by leveraging NVIDIA GPUs and communication libraries such as NCCL/NVSHMEM.

WholeGraph is a low-level graph storage library, integrated into and able to work alongside cuGraph, that directly provides an efficient feature and graph store with associated primitive operations (e.g., GPU-accelerated fast embedding retrieval and graph sampling). It is specifically optimized for NVLink systems, including DGX, MGX, and GH200/GB200 machines and clusters.

This example demonstrates how to use WholeGraph to distribute the graph and feature store to pinned host memory for fast GPU UVA access (see the `DistTensor` class), eliminating the need for manual graph partitioning or custom third-party launch scripts. WholeGraph integrates seamlessly with PyTorch's Distributed Data Parallel (DDP) setup and works with standard distributed job launchers such as torchrun, mpirun, or srun; a minimal sketch of how the stores plug into a PyG loader follows.
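For orientation, the sketch below shows how these stores plug into a standard PyG loader. It mirrors the `UVA-features` path of `benchmark_data.py` in this directory (the `WholeGraphFeatureStore`/`WholeGraphGraphStore` classes come from the example's `feature_store`/`graph_store` modules) and assumes a distributed process group has already been initialized, e.g. by torchrun; it is not a separate file in this example.

```python
# Minimal sketch, mirroring the UVA-features mode of benchmark_data.py.
# Assumes torch.distributed (and WholeGraph communicators) are already set up.
from ogb.nodeproppred import PygNodePropPredDataset
from feature_store import WholeGraphFeatureStore   # from this example
from graph_store import WholeGraphGraphStore       # from this example

from torch_geometric.loader import NeighborLoader

dataset = PygNodePropPredDataset(name='ogbn-products', root='/workspace')
data = dataset[0]
split_idx = dataset.get_idx_split()

# Place node features and graph structure in WholeMemory (pinned host memory,
# accessed from the GPU via UVA):
feature_store = WholeGraphFeatureStore(pyg_data=data)
graph_store = WholeGraphGraphStore(pyg_data=data, format='pyg')

train_loader = NeighborLoader(
    data=(feature_store, graph_store),  # PyG remote backend: (FeatureStore, GraphStore)
    input_nodes=split_idx['train'],
    num_neighbors=[30, 30],
    batch_size=1024,
    filter_per_worker=False,  # WholeGraph feature fetching is not fork-safe
    shuffle=True,
)
```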
## Requirements

- **PyTorch**: `>= 2.0`
- **PyTorch Geometric**: `>= 2.0.0`
- **WholeGraph**: `>= 24.02`
- **NVIDIA GPU(s)**

## Environment Setup

```bash
pip install pylibwholegraph-cu12
```
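Optionally, you can verify the installation before launching any jobs; a quick check such as the following (not part of the example) should run without errors:

```python
# Hypothetical sanity check: the WholeGraph Python bindings should import cleanly.
import pylibwholegraph.torch as wgth

print("pylibwholegraph available:", wgth.__name__)
```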
## Single/Multi-GPU Run

Using the PyTorch torchrun elastic launcher:
```
torchrun papers100m_dist_wholegraph_nc.py
```
or, using multiple GPUs if applicable:
```
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> papers100m_dist_wholegraph_nc.py
```

## Distributed (multi-node) Run

For example, using the Slurm launcher:

```
srun -N<num_nodes> --ntasks-per-node=<ngpu_per_node> python papers100m_dist_wholegraph_nc.py
```

Note that the command line above is simplified for demonstration purposes. Since cluster setups vary, please refer to this [sbatch script](https://github.com/chang-l/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling_multinode.sbatch) for more details. Regardless of the launcher, the example scripts derive their process placement from environment variables, as sketched below.
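The following sketch mirrors the bottom of `benchmark_data.py`: torchrun sets `WORLD_SIZE`/`RANK`/`LOCAL_RANK`, while under Slurm the `SLURM_*` variables are used as a fallback.

```python
import os

import torch

# Resolve process placement from torchrun (WORLD_SIZE/RANK/LOCAL_RANK) or,
# failing that, from Slurm (SLURM_NTASKS/SLURM_PROCID/SLURM_LOCALID):
world_size = int(os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS')))
rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID')))
local_rank = int(os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID')))

# Each rank drives one GPU:
assert torch.cuda.is_available()
device = torch.device(local_rank)
torch.cuda.set_device(device)
```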
## Benchmark Run

The benchmark script is similar to the example above but adds a `--mode` command-line argument, allowing users to easily compare PyG's native feature/graph stores (`torch_geometric.data.Data` and `torch_geometric.data.HeteroData`) with the WholeMemory-based feature store and graph store shown in this example. It performs a node classification task on the `ogbn-products` dataset.

### PyG baseline
```
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> benchmark_data.py --mode baseline
```

### WholeGraph FeatureStore integration (UVA for feature store access)
```
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> benchmark_data.py --mode UVA-features
```

### WholeGraph FeatureStore + GraphStore (UVA for feature and graph store access)
```
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> benchmark_data.py --mode UVA
```
examples/distributed/NVIDIA-RAPIDS/wholegraph/benchmark_data.py (231 additions, 0 deletions)
@@ -0,0 +1,231 @@
"""Multi-node multi-GPU example on ogbn-products.

Example way to run using srun:
srun -l -N<num_nodes> --ntasks-per-node=<ngpu_per_node> \
--container-name=cont --container-image=<image_url> \
--container-mounts=/ogb-papers100m/:/workspace/dataset
python3 path_to_script.py
"""
import argparse
import os
import time
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from feature_store import WholeGraphFeatureStore
from graph_store import WholeGraphGraphStore
from nv_distributed_graph import dist_shmem
from ogb.nodeproppred import PygNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Accuracy

from torch_geometric.loader import NeighborLoader, NodeLoader
from torch_geometric.nn import GCN
from torch_geometric.sampler import BaseSampler


class WholeGraphSampler(BaseSampler):
    r"""A naive sampler class for WholeGraph graph storage that only supports
    uniform node-based sampling on homogeneous graphs.
    """
    from torch_geometric.sampler import NodeSamplerInput, SamplerOutput

    def __init__(
        self,
        graph: WholeGraphGraphStore,
        num_neighbors,
    ):
        import pylibwholegraph.torch as wgth

        self.num_neighbors = num_neighbors
        self.wg_sampler = wgth.GraphStructure()
        row_indx, col_ptrs, _ = graph.csc()
        self.wg_sampler.set_csr_graph(col_ptrs._tensor, row_indx._tensor)

    def sample_from_nodes(self, inputs: NodeSamplerInput) -> SamplerOutput:
        r"""Sample subgraphs from the given nodes based on uniform node-based
        sampling.
        """
        seed = inputs.node.cuda(
            non_blocking=True)  # WholeGraph Sampler needs all seeds on device
        WG_SampleOutput = self.wg_sampler.multilayer_sample_without_replacement(
            seed, self.num_neighbors, None)
        out = WholeGraphGraphStore.create_pyg_subgraph(WG_SampleOutput)
        out.metadata = (inputs.input_id, inputs.time)
        return out

def run(world_size, rank, local_rank, device, mode):
    wall_clock_start = time.perf_counter()

    # Will query the runtime environment for `MASTER_ADDR` and `MASTER_PORT`.
    # Make sure those are set!
    dist.init_process_group('nccl', world_size=world_size, rank=rank)
    dist_shmem.init_process_group_per_node()

    # Load the dataset in the local root process and share it with local ranks:
    if dist_shmem.get_local_rank() == 0:
        dataset = PygNodePropPredDataset(name='ogbn-products',
                                         root='/workspace')
    else:
        dataset = None
    dataset = dist_shmem.to_shmem(dataset)  # Move the dataset to shmem.

    split_idx = dataset.get_idx_split()
    split_idx['train'] = split_idx['train'].split(
        split_idx['train'].size(0) // world_size, dim=0)[rank].clone()
    split_idx['valid'] = split_idx['valid'].split(
        split_idx['valid'].size(0) // world_size, dim=0)[rank].clone()
    split_idx['test'] = split_idx['test'].split(
        split_idx['test'].size(0) // world_size, dim=0)[rank].clone()
    data = dataset[0]
    num_features = dataset.num_features
    num_classes = dataset.num_classes

    if mode == 'baseline':
        kwargs = dict(
            data=data,
            batch_size=1024,
            num_neighbors=[30, 30],
            num_workers=4,
        )
        train_loader = NeighborLoader(
            input_nodes=split_idx['train'],
            shuffle=True,
            drop_last=True,
            **kwargs,
        )
        val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
        test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)

    elif mode == 'UVA-features':
        feature_store = WholeGraphFeatureStore(pyg_data=data)
        graph_store = WholeGraphGraphStore(pyg_data=data, format='pyg')
        data = (feature_store, graph_store)
        kwargs = dict(
            data=data,
            batch_size=1024,
            num_neighbors=[30, 30],
            num_workers=4,
            filter_per_worker=False,  # WholeGraph feature fetching is not fork-safe
        )
        train_loader = NeighborLoader(
            input_nodes=split_idx['train'],
            shuffle=True,
            drop_last=True,
            **kwargs,
        )
        val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
        test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)

    elif mode == 'UVA':
        feature_store = WholeGraphFeatureStore(pyg_data=data)
        graph_store = WholeGraphGraphStore(pyg_data=data)
        data = (feature_store, graph_store)
        kwargs = dict(
            data=data,
            batch_size=1024,
            num_workers=0,  # The WholeGraph sampler does not need workers.
            filter_per_worker=False,  # WholeGraph feature fetching is not fork-safe
        )
        node_sampler = WholeGraphSampler(
            graph_store,
            num_neighbors=[30, 30],
        )
        train_loader = NodeLoader(
            input_nodes=split_idx['train'],
            node_sampler=node_sampler,
            shuffle=True,
            drop_last=True,
            **kwargs,
        )
        val_loader = NodeLoader(input_nodes=split_idx['valid'],
                                node_sampler=node_sampler, **kwargs)
        test_loader = NodeLoader(input_nodes=split_idx['test'],
                                 node_sampler=node_sampler, **kwargs)

    eval_steps = 1000
    model = GCN(num_features, 256, 2, num_classes)
    acc = Accuracy(task="multiclass", num_classes=num_classes).to(device)
    model = DistributedDataParallel(model.to(device), device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
                                 weight_decay=5e-4)

    if rank == 0:
        prep_time = round(time.perf_counter() - wall_clock_start, 2)
        print("Total time before training begins (prep_time)=", prep_time,
              "seconds")
        print("Beginning training...")

    for epoch in range(1, 21):
        dist.barrier()
        start = time.time()
        model.train()
        for i, batch in enumerate(train_loader):
            batch = batch.to(device)
            optimizer.zero_grad()
            y = batch.y[:batch.batch_size].view(-1).to(torch.long)
            out = model(batch.x, batch.edge_index)[:batch.batch_size]
            loss = F.cross_entropy(out, y)
            loss.backward()
            optimizer.step()
            if rank == 0 and i % 100 == 0:
                print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}')

        # Profile run:
        # We synchronize before the barrier to flush GPU ops first, then add a
        # barrier to sync CPUs and find the max train time among all ranks.
        torch.cuda.synchronize()
        dist.barrier()
        epoch_end = time.time()

        @torch.no_grad()
        def test(loader: NodeLoader, num_steps: Optional[int] = None):
            model.eval()
            for j, batch in enumerate(loader):
                if num_steps is not None and j >= num_steps:
                    break
                batch = batch.to(device)
                out = model(batch.x, batch.edge_index)[:batch.batch_size]
                y = batch.y[:batch.batch_size].view(-1).to(torch.long)
                acc(out, y)
            acc_sum = acc.compute()
            return acc_sum

        eval_acc = test(val_loader, num_steps=eval_steps)
        if rank == 0:
            print(f"Val Accuracy: {eval_acc:.4f}%")
            print(f"Epoch {epoch:05d} | "
                  f"Accuracy {eval_acc:.4f} | "
                  f"Time {epoch_end - start:.2f}")

        acc.reset()
        dist.barrier()

    test_acc = test(test_loader)
    if rank == 0:
        print(f"Test Accuracy: {test_acc:.4f}%")
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='baseline',
                        choices=['baseline', 'UVA-features', 'UVA'])
    args = parser.parse_args()

    # Get the world size from the WORLD_SIZE variable or directly from SLURM:
    world_size = int(
        os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS')))
    # Likewise for RANK and LOCAL_RANK:
    rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID')))
    local_rank = int(
        os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID')))

    assert torch.cuda.is_available()
    device = torch.device(local_rank)
    torch.cuda.set_device(device)
    run(world_size, rank, local_rank, device, args.mode)
@alexbarghi-nv do you mind adding a README file under this directory later?