[Example] Add WholeGraph to accelerate PyG dataloaders with GPUs #9714

Open · wants to merge 8 commits into master
8 changes: 8 additions & 0 deletions examples/distributed/NVIDIA-RAPIDS/README.md
@@ -0,0 +1,8 @@
# Distributed Training with PyG using NVIDIA RAPIDS libraries

This directory contains examples of distributed graph learning using the NVIDIA RAPIDS cuGraph and WholeGraph libraries. These examples minimize CPU overhead and maximize GPU throughput during the graph dataloading stage. In our tests, we typically observe at least a tenfold speedup over traditional CPU-based [RPC methods](../pyg). The libraries are also user-friendly, integrating into existing GNN training workflows with minimal changes.

Currently, we offer two integration options for NVIDIA RAPIDS support: the first goes through cuGraph, which provides a higher-level API (the cuGraph dataloader); the second goes through WholeGraph, which leverages PyG's remote backend APIs for greater flexibility in accelerating GNN training and other GraphML tasks. We plan to merge these two paths under [cugraph-gnn](https://github.com/rapidsai/cugraph-gnn), creating a unified, multi-level API that flattens the learning curve.

1. [`cuGraph`](./cugraph): Distributed training via NVIDIA RAPIDS [cuGraph](https://github.com/rapidsai/cugraph) library.
1. [`WholeGraph`](./wholegraph): Distributed training via PyG remote backend APIs and NVIDIA RAPIDS [WholeGraph](https://github.com/rapidsai/wholegraph) library.
Author comment: @alexbarghi-nv do you mind adding a README file under this directory later?

File renamed without changes.
62 changes: 62 additions & 0 deletions examples/distributed/NVIDIA-RAPIDS/wholegraph/README
@@ -0,0 +1,62 @@
# Using NVIDIA WholeGraph Library for Distributed Training with PyG

**[RAPIDS WholeGraph](https://github.com/rapidsai/wholegraph)**
NVIDIA WholeGraph is designed to optimize the training of Graph Neural Networks (GNNs) that are often constrained by data loading operations. It provides an underlying storage structure, called WholeMemory, which efficiently manages data storage/communication across disk, RAM, and device memory by leveraging NVIDIA GPUs and communication libraries like NCCL/NVSHMEM.

WholeGraph is a low-level graph storage library, integrated into and able to work alongside cuGraph, that directly provides an efficient feature and graph store along with primitive operations on them (e.g., GPU-accelerated fast embedding retrieval and graph sampling). It is specifically optimized for NVLink systems, including DGX, MGX, and GH/GB200 machines and clusters.

This example demonstrates how to use WholeGraph to easily distribute the graph and feature stores into pinned host memory for fast GPU UVA access (see the `DistTensor` class), eliminating the need for manual graph partitioning or any custom third-party launch scripts. WholeGraph seamlessly integrates with PyTorch's Distributed Data Parallel (DDP) setup and works with standard distributed job launchers such as torchrun, mpirun, or srun.
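The following is a minimal sketch of that remote-backend pattern, assuming the `WholeGraphFeatureStore`, `WholeGraphGraphStore`, and `dist_shmem` helpers shipped with this example and a toy in-memory graph; `benchmark_data.py` contains the full, runnable version on `ogbn-products`.

```python
import torch
import torch.distributed as dist
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

from feature_store import WholeGraphFeatureStore   # provided by this example
from graph_store import WholeGraphGraphStore       # provided by this example
from nv_distributed_graph import dist_shmem        # provided by this example

dist.init_process_group('nccl')           # launched via torchrun/srun
dist_shmem.init_process_group_per_node()  # node-local group for shared-memory staging

# Any in-memory PyG graph works here; a tiny random graph keeps the sketch short.
data = Data(x=torch.randn(1000, 128),
            edge_index=torch.randint(0, 1000, (2, 5000)),
            y=torch.randint(0, 10, (1000,)))

# Features and topology are placed in WholeMemory (pinned host memory),
# which the GPU reads directly over UVA -- no manual partitioning needed.
feature_store = WholeGraphFeatureStore(pyg_data=data)
graph_store = WholeGraphGraphStore(pyg_data=data, format='pyg')

loader = NeighborLoader(
    data=(feature_store, graph_store),  # PyG remote-backend tuple
    input_nodes=torch.arange(1000),
    num_neighbors=[30, 30],
    batch_size=256,
    filter_per_worker=False,  # WholeGraph feature fetching is not fork-safe
)
```

Each rank runs the same loader code, with WholeMemory handling the underlying storage placement and sharing across processes.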

## Requirements

- **PyTorch**: `>= 2.0`
- **PyTorch Geometric**: `>= 2.0.0`
- **WholeGraph**: `>= 24.02`
- **NVIDIA GPU(s)**

## Environment Setup

```bash
pip install pylibwholegraph-cu12
```

## Single/Multi-GPU Run

Using the PyTorch `torchrun` elastic launcher:

```bash
torchrun papers100m_dist_wholegraph_nc.py
```

or, to use multiple GPUs if available:

```bash
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> papers100m_dist_wholegraph_nc.py
```

## Distributed (multi-node) Run

For example, using the Slurm launcher:

```bash
srun -N<num_nodes> --ntasks-per-node=<ngpu_per_node> python papers100m_dist_wholegraph_nc.py
```

Note that the above command line is simplified for demonstration purposes. Since cluster setups vary, please refer to this [sbatch script](https://github.com/chang-l/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling_multinode.sbatch) for more details.
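When launched with `srun` directly (rather than through `torchrun`), the example script falls back to SLURM's environment variables to derive its ranks. Roughly, it does the following (this mirrors the bottom of `benchmark_data.py`; `MASTER_ADDR`/`MASTER_PORT` must still be set for `dist.init_process_group`):

```python
import os

# Fallback used by the example scripts when torchrun's env vars are absent:
world_size = int(os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS')))
rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID')))
local_rank = int(os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID')))
```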


## Benchmark Run

The benchmark script is similar to the example above but adds a `--mode` command-line argument, letting users compare PyG's native feature/graph stores (`torch_geometric.data.Data` and `torch_geometric.data.HeteroData`) with the WholeMemory-based feature and graph stores used in this example. It performs a node classification task on the `ogbn-products` dataset.
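Schematically, the three modes build the loader input as follows (abridged from `benchmark_data.py`; see the full file below):

```python
# Abridged from benchmark_data.py: what each --mode passes to the loader.
if mode == 'baseline':
    pass  # plain in-memory torch_geometric.data.Data + default CPU sampling
elif mode == 'UVA-features':
    # WholeMemory-backed features (UVA); graph kept in PyG format for NeighborLoader
    data = (WholeGraphFeatureStore(pyg_data=data),
            WholeGraphGraphStore(pyg_data=data, format='pyg'))
elif mode == 'UVA':
    # WholeMemory-backed features and graph; sampled on GPU by WholeGraphSampler
    data = (WholeGraphFeatureStore(pyg_data=data),
            WholeGraphGraphStore(pyg_data=data))
```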

### PyG baseline
```bash
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> benchmark_data.py --mode baseline
```

### WholeGraph FeatureStore integration (UVA for feature store access)
```bash
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> benchmark_data.py --mode UVA-features
```

### WholeGraph FeatureStore + GraphStore (UVA for feature and graph store access)
```bash
torchrun --nnodes 1 --nproc-per-node <ngpu_per_node> benchmark_data.py --mode UVA
```
231 changes: 231 additions & 0 deletions examples/distributed/NVIDIA-RAPIDS/wholegraph/benchmark_data.py
@@ -0,0 +1,231 @@
"""Multi-node multi-GPU example on ogbn-papers100m.

Example way to run using srun:
srun -l -N<num_nodes> --ntasks-per-node=<ngpu_per_node> \
--container-name=cont --container-image=<image_url> \
--container-mounts=/ogb-papers100m/:/workspace/dataset \
python3 path_to_script.py
"""
import argparse
import os
import time
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from feature_store import WholeGraphFeatureStore
from graph_store import WholeGraphGraphStore
from nv_distributed_graph import dist_shmem
from ogb.nodeproppred import PygNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Accuracy

from torch_geometric.loader import NeighborLoader, NodeLoader
from torch_geometric.nn import GCN
from torch_geometric.sampler import (
    BaseSampler,
    NodeSamplerInput,
    SamplerOutput,
)


class WholeGraphSampler(BaseSampler):
    r"""A naive sampler for WholeGraph graph storage that only supports
    uniform node-based sampling on a homogeneous graph.
    """

def __init__(
self,
graph: WholeGraphGraphStore,
num_neighbors,
):
import pylibwholegraph.torch as wgth

self.num_neighbors = num_neighbors
self.wg_sampler = wgth.GraphStructure()
row_indx, col_ptrs, _ = graph.csc()
self.wg_sampler.set_csr_graph(col_ptrs._tensor, row_indx._tensor)

def sample_from_nodes(self, inputs: NodeSamplerInput) -> SamplerOutput:
r"""Sample subgraphs from the given nodes based on uniform node-based sampling.
"""
seed = inputs.node.cuda(
non_blocking=True) # WholeGraph Sampler needs all seeds on device
WG_SampleOutput = self.wg_sampler.multilayer_sample_without_replacement(
seed, self.num_neighbors, None)
out = WholeGraphGraphStore.create_pyg_subgraph(WG_SampleOutput)
out.metadata = (inputs.input_id, inputs.time)
return out


def run(world_size, rank, local_rank, device, mode):
wall_clock_start = time.perf_counter()

# Will query the runtime environment for `MASTER_ADDR` and `MASTER_PORT`.
    # Make sure those are set!
dist.init_process_group('nccl', world_size=world_size, rank=rank)
dist_shmem.init_process_group_per_node()

# Load the dataset in the local root process and share it with local ranks
if dist_shmem.get_local_rank() == 0:
dataset = PygNodePropPredDataset(name='ogbn-products',
root='/workspace')
else:
dataset = None
dataset = dist_shmem.to_shmem(dataset) # move dataset to shmem

split_idx = dataset.get_idx_split()
split_idx['train'] = split_idx['train'].split(
split_idx['train'].size(0) // world_size, dim=0)[rank].clone()
split_idx['valid'] = split_idx['valid'].split(
split_idx['valid'].size(0) // world_size, dim=0)[rank].clone()
split_idx['test'] = split_idx['test'].split(
split_idx['test'].size(0) // world_size, dim=0)[rank].clone()
data = dataset[0]
num_features = dataset.num_features
num_classes = dataset.num_classes

if mode == 'baseline':
        # Keep the plain in-memory PyG Data object for the baseline.
kwargs = dict(
data=data,
batch_size=1024,
num_neighbors=[30, 30],
num_workers=4,
)
train_loader = NeighborLoader(
input_nodes=split_idx['train'],
shuffle=True,
drop_last=True,
**kwargs,
)
val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)

elif mode == 'UVA-features':
feature_store = WholeGraphFeatureStore(pyg_data=data)
graph_store = WholeGraphGraphStore(pyg_data=data, format='pyg')
data = (feature_store, graph_store)
kwargs = dict(
data=data,
batch_size=1024,
num_neighbors=[30, 30],
num_workers=4,
            filter_per_worker=False,  # WholeGraph feature fetching is not fork-safe
)
train_loader = NeighborLoader(
input_nodes=split_idx['train'],
shuffle=True,
drop_last=True,
**kwargs,
)
val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)

elif mode == 'UVA':
feature_store = WholeGraphFeatureStore(pyg_data=data)
graph_store = WholeGraphGraphStore(pyg_data=data)
data = (feature_store, graph_store)
kwargs = dict(
data=data,
batch_size=1024,
            num_workers=0,  # the GPU-based WholeGraph sampler needs no CPU workers
            filter_per_worker=False,  # WholeGraph feature fetching is not fork-safe
)
node_sampler = WholeGraphSampler(
graph_store,
num_neighbors=[30, 30],
)
train_loader = NodeLoader(
input_nodes=split_idx['train'],
node_sampler=node_sampler,
shuffle=True,
drop_last=True,
**kwargs,
)
val_loader = NodeLoader(input_nodes=split_idx['valid'],
node_sampler=node_sampler, **kwargs)
test_loader = NodeLoader(input_nodes=split_idx['test'],
node_sampler=node_sampler, **kwargs)

eval_steps = 1000
model = GCN(num_features, 256, 2, num_classes)
acc = Accuracy(task="multiclass", num_classes=num_classes).to(device)
model = DistributedDataParallel(model.to(device), device_ids=[local_rank])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
weight_decay=5e-4)

if rank == 0:
prep_time = round(time.perf_counter() - wall_clock_start, 2)
print("Total time before training begins (prep_time)=", prep_time,
"seconds")
print("Beginning training...")

for epoch in range(1, 21):
dist.barrier()
start = time.time()
model.train()
for i, batch in enumerate(train_loader):
batch = batch.to(device)
optimizer.zero_grad()
y = batch.y[:batch.batch_size].view(-1).to(torch.long)
out = model(batch.x, batch.edge_index)[:batch.batch_size]
loss = F.cross_entropy(out, y)
loss.backward()
optimizer.step()
if rank == 0 and i % 100 == 0:
print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}')

        # Profile run:
        # Synchronize first to flush outstanding GPU ops, then add a barrier
        # to sync CPUs and measure the max epoch time across all ranks.
torch.cuda.synchronize()
dist.barrier()
epoch_end = time.time()

@torch.no_grad()
def test(loader: NodeLoader, num_steps: Optional[int] = None):
model.eval()
for j, batch in enumerate(loader):
if num_steps is not None and j >= num_steps:
break
batch = batch.to(device)
out = model(batch.x, batch.edge_index)[:batch.batch_size]
y = batch.y[:batch.batch_size].view(-1).to(torch.long)
acc(out, y)
acc_sum = acc.compute()
return acc_sum

eval_acc = test(val_loader, num_steps=eval_steps)
if rank == 0:
print(f"Val Accuracy: {eval_acc:.4f}%", )
print(f"Epoch {epoch:05d} | "
f"Accuracy {eval_acc:.4f} | "
f"Time {epoch_end - start:.2f}")

acc.reset()
dist.barrier()

test_acc = test(test_loader)
if rank == 0:
print(f"Test Accuracy: {test_acc:.4f}%", )
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, default='baseline',
choices=['baseline', 'UVA-features', 'UVA'])
args = parser.parse_args()

# Get the world size from the WORLD_SIZE variable or directly from SLURM:
world_size = int(
os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS')))
# Likewise for RANK and LOCAL_RANK:
rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID')))
local_rank = int(
os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID')))

assert torch.cuda.is_available()
device = torch.device(local_rank)
torch.cuda.set_device(device)
run(world_size, rank, local_rank, device, args.mode)