Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 28, 2022
1 parent 255bfd7 commit 3b24091
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions examples/multi_gpu/distributed_sampling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import os

import torch
Expand All @@ -6,7 +7,7 @@
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm
import copy

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
Expand Down Expand Up @@ -64,12 +65,13 @@ def run(rank, world_size, dataset):

kwargs = {'batch_size': 1024, 'num_workers': 0}
train_loader = NeighborLoader(data, input_nodes=data.train_mask,
num_neighbors=[25, 10],
shuffle=True, drop_last=True, **kwargs)
num_neighbors=[25, 10], shuffle=True,
drop_last=True, **kwargs)

if rank == 0:
subgraph_loader = NeighborLoader(copy.copy(data), input_nodes=None,
num_neighbors=[-1], shuffle=False, **kwargs)
num_neighbors=[-1], shuffle=False,
**kwargs)
# No need to maintain these features during evaluation:
del subgraph_loader.data.x, subgraph_loader.data.y
# Add global node index information.
Expand Down
2 changes: 1 addition & 1 deletion training_benchmark.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# python training_benchmark.py --datasets Reddit --models gcn
python examples/multi_gpu/distributed_sampling.py >
python examples/multi_gpu/distributed_sampling.py >

0 comments on commit 3b24091

Please sign in to comment.