
Commit

Replace distributed sampling
LukeLIN-web committed Dec 28, 2022
1 parent fe6150b commit 260a1c2
Showing 3 changed files with 41 additions and 161 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.3.0] - 2023-MM-DD
 ### Added
+- Added `distributed_sampling_loader.py` as an example of using `NeighborLoader` with DDP ([#6307](https://github.com/pyg-team/pytorch_geometric/pull/6307))
 - Added `PGExplainer` ([#6204](https://github.com/pyg-team/pytorch_geometric/pull/6204))
 - Added the `AirfRANS` dataset ([#6287](https://github.com/pyg-team/pytorch_geometric/pull/6287))
 - Added `AttentionExplainer` ([#6279](https://github.com/pyg-team/pytorch_geometric/pull/6279))
@@ -35,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
 - Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
 ### Changed
+- Fixed the filtering of node features in `transforms.RemoveIsolatedNodes` ([#6308](https://github.com/pyg-team/pytorch_geometric/pull/6308))
 - Fixed a bug in `DimeNet` that caused an output dimension mismatch ([#6305](https://github.com/pyg-team/pytorch_geometric/pull/6305))
 - Fixed `Data.to_heterogeneous()` with empty `edge_index` ([#6304](https://github.com/pyg-team/pytorch_geometric/pull/6304))
 - Unify `Explanation.node_mask` and `Explanation.node_feat_mask` ([#6267](https://github.com/pyg-team/pytorch_geometric/pull/6267))
80 changes: 39 additions & 41 deletions examples/multi_gpu/distributed_sampling.py
@@ -1,3 +1,4 @@
+import copy
 import os
 
 import torch
@@ -6,9 +7,9 @@
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel
 from tqdm import tqdm
 
-import copy
 from torch_geometric.datasets import Reddit
-from torch_geometric.loader import NeighborSampler
+from torch_geometric.loader import NeighborLoader
 from torch_geometric.nn import SAGEConv
 

@@ -24,37 +25,33 @@ def __init__(self, in_channels, hidden_channels, out_channels,
             self.convs.append(SAGEConv(hidden_channels, hidden_channels))
         self.convs.append(SAGEConv(hidden_channels, out_channels))
 
-    def forward(self, x, adjs):
-        for i, (edge_index, _, size) in enumerate(adjs):
-            x_target = x[:size[1]]  # Target nodes are always placed first.
-            x = self.convs[i]((x, x_target), edge_index)
-            if i != self.num_layers - 1:
-                x = F.relu(x)
+    def forward(self, x, edge_index):
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            if i < len(self.convs) - 1:
+                x = x.relu_()
                 x = F.dropout(x, p=0.5, training=self.training)
-        return x.log_softmax(dim=-1)
+        return x
 
     @torch.no_grad()
-    def inference(self, x_all, device, subgraph_loader):
-        pbar = tqdm(total=x_all.size(0) * self.num_layers)
+    def inference(self, x_all, rank, subgraph_loader):
+        pbar = tqdm(total=len(subgraph_loader.dataset) * len(self.convs))
         pbar.set_description('Evaluating')
 
-        for i in range(self.num_layers):
+        # Compute representations of nodes layer by layer, using *all*
+        # available edges. This leads to faster computation in contrast to
+        # immediately computing the final representations of each batch:
+        for i, conv in enumerate(self.convs):
             xs = []
-            for batch_size, n_id, adj in subgraph_loader:
-                edge_index, _, size = adj.to(device)
-                x = x_all[n_id].to(device)
-                x_target = x[:size[1]]
-                x = self.convs[i]((x, x_target), edge_index)
-                if i != self.num_layers - 1:
-                    x = F.relu(x)
-                xs.append(x.cpu())
-
-                pbar.update(batch_size)
-
+            for batch in subgraph_loader:
+                x = x_all[batch.n_id.to(x_all.device)].to(rank)
+                x = conv(x, batch.edge_index.to(rank))
+                if i < len(self.convs) - 1:
+                    x = x.relu_()
+                xs.append(x[:batch.batch_size].cpu())
+                pbar.update(batch.batch_size)
             x_all = torch.cat(xs, dim=0)
 
         pbar.close()
 
         return x_all

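Note (not part of the diff): the rewritten `forward()` takes a plain `edge_index` instead of the layer-wise `adjs` list, and `inference()` now iterates over `NeighborLoader` batches. Below is a minimal sketch of how such a batch is consumed, assuming the `SAGE` class from the hunk above; the dataset path and batch size are illustrative placeholders, not taken from the commit.

import torch.nn.functional as F

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader

dataset = Reddit('../../data/Reddit')  # placeholder path
data = dataset[0]

loader = NeighborLoader(data, input_nodes=data.train_mask,
                        num_neighbors=[25, 10], batch_size=1024, shuffle=True)

model = SAGE(dataset.num_features, 256, dataset.num_classes)
batch = next(iter(loader))
# Seed nodes come first in every sampled subgraph, so the first
# `batch.batch_size` rows of the output are their predictions:
out = model(batch.x, batch.edge_index)[:batch.batch_size]
loss = F.cross_entropy(out, batch.y[:batch.batch_size])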
@@ -64,34 +61,35 @@ def run(rank, world_size, dataset):
     dist.init_process_group('nccl', rank=rank, world_size=world_size)
 
     data = dataset[0]
+    data = data.to(rank, 'x', 'y')
 
     train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
     train_idx = train_idx.split(train_idx.size(0) // world_size)[rank]
 
-    train_loader = NeighborSampler(data.edge_index, node_idx=train_idx,
-                                   sizes=[25, 10], batch_size=1024,
-                                   shuffle=True, num_workers=0)
+    kwargs = {'batch_size': 1024, 'num_workers': 0}
+    train_loader = NeighborLoader(data, input_nodes=data.train_mask,
+                                  num_neighbors=[25, 10],
+                                  shuffle=True, drop_last=True, **kwargs)
 
     if rank == 0:
-        subgraph_loader = NeighborSampler(data.edge_index, node_idx=None,
-                                          sizes=[-1], batch_size=2048,
-                                          shuffle=False, num_workers=6)
+        subgraph_loader = NeighborLoader(copy.copy(data), input_nodes=None,
+                                         num_neighbors=[-1], shuffle=False, **kwargs)
+        # No need to maintain these features during evaluation:
+        del subgraph_loader.data.x, subgraph_loader.data.y
+        # Add global node index information.
+        subgraph_loader.data.num_nodes = data.num_nodes
+        subgraph_loader.data.n_id = torch.arange(data.num_nodes)
 
     torch.manual_seed(12345)
     model = SAGE(dataset.num_features, 256, dataset.num_classes).to(rank)
     model = DistributedDataParallel(model, device_ids=[rank])
     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
-    x, y = data.x.to(rank), data.y.to(rank)
-
     for epoch in range(1, 21):
         model.train()
 
-        for batch_size, n_id, adjs in train_loader:
-            adjs = [adj.to(rank) for adj in adjs]
-
+        for batch in train_loader:
             optimizer.zero_grad()
-            out = model(x[n_id], adjs)
-            loss = F.nll_loss(out, y[n_id[:batch_size]])
+            out = model(batch.x, batch.edge_index.to(rank))[:batch.batch_size]
+            loss = F.cross_entropy(out, batch.y[:batch.batch_size])
             loss.backward()
             optimizer.step()

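Aside (a variant sketch, not part of the commit): the per-rank `train_idx` split computed above no longer appears to feed the loader, since `input_nodes=data.train_mask` gives every rank the same seed pool. If per-rank sharding of the training seeds is desired, as in the `NeighborSampler` version, `NeighborLoader` also accepts an index tensor as `input_nodes`:

# Variant, not in the diff: each process samples a disjoint shard of seeds.
train_loader = NeighborLoader(data, input_nodes=train_idx,
                              num_neighbors=[25, 10],
                              shuffle=True, drop_last=True, **kwargs)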
@@ -103,7 +101,7 @@ def run(rank, world_size, dataset):
         if rank == 0 and epoch % 5 == 0:  # We evaluate on a single GPU for now
             model.eval()
             with torch.no_grad():
-                out = model.module.inference(x, rank, subgraph_loader)
+                out = model.module.inference(data.x, rank, subgraph_loader)
             res = out.argmax(dim=-1) == data.y
             acc1 = int(res[data.train_mask].sum()) / int(data.train_mask.sum())
             acc2 = int(res[data.val_mask].sum()) / int(data.val_mask.sum())
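The entry point of the example is not shown in this diff (so presumably unchanged by the commit). As a rough sketch, with the dataset path as a placeholder, such a `run()` function is typically launched with one process per GPU along these lines:

import torch
import torch.multiprocessing as mp

from torch_geometric.datasets import Reddit

# Launch sketch, assumed rather than taken from the diff.
if __name__ == '__main__':
    dataset = Reddit('../../data/Reddit')  # placeholder path
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)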
120 changes: 0 additions & 120 deletions examples/multi_gpu/distributed_sampling_loader.py

This file was deleted.
