Skip to content

Commit

Permalink
HeteroData support in RandomNodeLoader (#6007)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Nov 18, 2022
1 parent 434b520 commit 009abc2
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 110 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added `HeteroData` support in `RandomNodeLoader` ([#6007](https://github.com/pyg-team/pytorch_geometric/pull/6007))
- Added bipartite `GraphSAGE` example ([#5834](https://github.com/pyg-team/pytorch_geometric/pull/5834))
- Added `LRGBDataset` to include 5 datasets from the [Long Range Graph Benchmark](https://openreview.net/pdf?id=in7XC5RcjEn) ([#5935](https://github.com/pyg-team/pytorch_geometric/pull/5935))
- Added a warning for invalid node and edge type names in `HeteroData` ([#5990](https://github.com/pyg-team/pytorch_geometric/pull/5990))
Expand Down
8 changes: 4 additions & 4 deletions examples/ogbn_proteins_deepgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch_scatter import scatter
from tqdm import tqdm

from torch_geometric.loader import RandomNodeSampler
from torch_geometric.loader import RandomNodeLoader
from torch_geometric.nn import DeepGCNLayer, GENConv

dataset = PygNodePropPredDataset('ogbn-proteins', root='../data')
Expand All @@ -24,9 +24,9 @@
mask[splitted_idx[split]] = True
data[f'{split}_mask'] = mask

train_loader = RandomNodeSampler(data, num_parts=40, shuffle=True,
num_workers=5)
test_loader = RandomNodeSampler(data, num_parts=5, num_workers=5)
train_loader = RandomNodeLoader(data, num_parts=40, shuffle=True,
num_workers=5)
test_loader = RandomNodeLoader(data, num_parts=5, num_workers=5)


class DeeperGCN(torch.nn.Module):
Expand Down
8 changes: 4 additions & 4 deletions examples/rev_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.loader import RandomNodeSampler
from torch_geometric.loader import RandomNodeLoader
from torch_geometric.nn import GroupAddRev, SAGEConv
from torch_geometric.utils import index_to_mask

Expand Down Expand Up @@ -91,11 +91,11 @@ def forward(self, x, edge_index):
for split in ['train', 'valid', 'test']:
data[f'{split}_mask'] = index_to_mask(split_idx[split], data.y.shape[0])

train_loader = RandomNodeSampler(data, num_parts=10, shuffle=True,
num_workers=5)
train_loader = RandomNodeLoader(data, num_parts=10, shuffle=True,
num_workers=5)
# Increase the num_parts of the test loader if you cannot fit
# the full batch graph into your GPU:
test_loader = RandomNodeSampler(data, num_parts=1, num_workers=5)
test_loader = RandomNodeLoader(data, num_parts=1, num_workers=5)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RevGNN(
Expand Down
52 changes: 52 additions & 0 deletions test/loader/test_random_node_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import RandomNodeLoader


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
return torch.stack([row, col], dim=0)


def test_random_node_loader():
data = Data()
data.x = torch.randn(100, 128)
data.node_id = torch.arange(100)
data.edge_index = get_edge_index(100, 100, 500)
data.edge_attr = torch.randn(500, 32)

loader = RandomNodeLoader(data, num_parts=4, shuffle=True)
assert len(loader) == 4

for batch in loader:
assert len(batch) == 4
assert batch.node_id.min() >= 0
assert batch.node_id.max() < 100
assert batch.edge_index.size(1) == batch.edge_attr.size(0)
assert torch.allclose(batch.x, data.x[batch.node_id])
batch.validate()


def test_heterogeneous_random_node_loader():
data = HeteroData()
data['paper'].x = torch.randn(100, 128)
data['paper'].node_id = torch.arange(100)
data['author'].x = torch.randn(200, 128)
data['author'].node_id = torch.arange(200)
data['paper', 'author'].edge_index = get_edge_index(100, 200, 500)
data['paper', 'author'].edge_attr = torch.randn(500, 32)
data['author', 'paper'].edge_index = get_edge_index(200, 100, 400)
data['author', 'paper'].edge_attr = torch.randn(400, 32)
data['paper', 'paper'].edge_index = get_edge_index(100, 100, 600)
data['paper', 'paper'].edge_attr = torch.randn(600, 32)

loader = RandomNodeLoader(data, num_parts=4, shuffle=True)
assert len(loader) == 4

for batch in loader:
assert len(batch) == 4
assert batch.node_types == data.node_types
assert batch.edge_types == data.edge_types
batch.validate()
6 changes: 3 additions & 3 deletions torch_geometric/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from torch_geometric.loader import GraphSAINTEdgeSampler # noqa
from torch_geometric.loader import GraphSAINTRandomWalkSampler # noqa
from torch_geometric.loader import ShaDowKHopSampler # noqa
from torch_geometric.loader import RandomNodeSampler # noqa
from torch_geometric.loader import RandomNodeLoader # noqa
from torch_geometric.loader import DataLoader # noqa
from torch_geometric.loader import DataListLoader # noqa
from torch_geometric.loader import DenseDataLoader # noqa
Expand All @@ -66,8 +66,8 @@
'data.GraphSAINTRandomWalkSampler')(GraphSAINTRandomWalkSampler)
ShaDowKHopSampler = deprecated("use 'loader.ShaDowKHopSampler' instead",
'data.ShaDowKHopSampler')(ShaDowKHopSampler)
RandomNodeSampler = deprecated("use 'loader.RandomNodeSampler' instead",
'data.RandomNodeSampler')(RandomNodeSampler)
RandomNodeSampler = deprecated("use 'loader.RandomNodeLoader' instead",
'data.RandomNodeSampler')(RandomNodeLoader)
DataLoader = deprecated("use 'loader.DataLoader' instead",
'data.DataLoader')(DataLoader)
DataListLoader = deprecated("use 'loader.DataListLoader' instead",
Expand Down
12 changes: 6 additions & 6 deletions torch_geometric/graphgym/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
GraphSAINTNodeSampler,
GraphSAINTRandomWalkSampler,
NeighborSampler,
RandomNodeSampler,
RandomNodeLoader,
)
from torch_geometric.utils import (
index_to_mask,
Expand Down Expand Up @@ -256,11 +256,11 @@ def get_loader(dataset, sampler, batch_size, shuffle=True):
batch_size=batch_size, shuffle=shuffle,
num_workers=cfg.num_workers, pin_memory=True)
elif sampler == "random_node":
loader_train = RandomNodeSampler(dataset[0],
num_parts=cfg.train.train_parts,
shuffle=shuffle,
num_workers=cfg.num_workers,
pin_memory=True)
loader_train = RandomNodeLoader(dataset[0],
num_parts=cfg.train.train_parts,
shuffle=shuffle,
num_workers=cfg.num_workers,
pin_memory=True)
elif sampler == "saint_rw":
loader_train = \
GraphSAINTRandomWalkSampler(dataset[0],
Expand Down
11 changes: 9 additions & 2 deletions torch_geometric/loader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from torch_geometric.deprecation import deprecated

from .dataloader import DataLoader
from .neighbor_loader import NeighborLoader
from .link_neighbor_loader import LinkNeighborLoader
Expand All @@ -6,7 +8,7 @@
from .graph_saint import (GraphSAINTSampler, GraphSAINTNodeSampler,
GraphSAINTEdgeSampler, GraphSAINTRandomWalkSampler)
from .shadow import ShaDowKHopSampler
from .random_node_sampler import RandomNodeSampler
from .random_node_loader import RandomNodeLoader
from .data_list_loader import DataListLoader
from .dense_data_loader import DenseDataLoader
from .temporal_dataloader import TemporalDataLoader
Expand All @@ -30,11 +32,16 @@
'GraphSAINTEdgeSampler',
'GraphSAINTRandomWalkSampler',
'ShaDowKHopSampler',
'RandomNodeSampler',
'RandomNodeLoader',
'DataListLoader',
'DenseDataLoader',
'TemporalDataLoader',
'NeighborSampler',
'ImbalancedSampler',
'DynamicBatchSampler',
]

RandomNodeSampler = deprecated(
details="use 'loader.RandomNodeLoader' instead",
func_name='loader.RandomNodeSampler',
)(RandomNodeLoader)
68 changes: 68 additions & 0 deletions torch_geometric/loader/random_node_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import math
from typing import Union

import torch
from torch import Tensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.hetero_data import to_homogeneous_edge_index


class RandomNodeLoader(torch.utils.data.DataLoader):
r"""A data loader that randomly samples nodes within a graph and returns
their induced subgraph.
.. note::
For an example of using
:class:`~torch_geometric.loader.RandomNodeLoader`, see
`examples/ogbn_proteins_deepgcn.py
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
ogbn_proteins_deepgcn.py>`_.
Args:
data (torch_geometric.data.Data or torch_geometric.data.HeteroData):
The :class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData` graph object.
num_parts (int): The number of partitions.
**kwargs (optional): Additional arguments of
:class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`.
"""
def __init__(
self,
data: Union[Data, HeteroData],
num_parts: int,
**kwargs,
):
self.data = data
self.num_parts = num_parts

if isinstance(data, HeteroData):
edge_index, node_dict, edge_dict = to_homogeneous_edge_index(data)
self.node_dict, self.edge_dict = node_dict, edge_dict
else:
edge_index = data.edge_index

self.edge_index = edge_index
self.num_nodes = data.num_nodes

super().__init__(
range(self.num_nodes),
batch_size=math.ceil(self.num_nodes / num_parts),
collate_fn=self.collate_fn,
**kwargs,
)

def collate_fn(self, index):
if not isinstance(index, Tensor):
index = torch.tensor(index)

if isinstance(self.data, Data):
return self.data.subgraph(index)

elif isinstance(self.data, HeteroData):
node_dict = {
key: index[(index >= start) & (index < end)] - start
for key, (start, end) in self.node_dict.items()
}
return self.data.subgraph(node_dict)
91 changes: 0 additions & 91 deletions torch_geometric/loader/random_node_sampler.py

This file was deleted.

0 comments on commit 009abc2

Please sign in to comment.