-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
147 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import torch | ||
|
||
from torch_geometric.data import Data, HeteroData | ||
from torch_geometric.loader import RandomNodeLoader | ||
|
||
|
||
def get_edge_index(num_src_nodes, num_dst_nodes, num_edges): | ||
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long) | ||
col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long) | ||
return torch.stack([row, col], dim=0) | ||
|
||
|
||
def test_random_node_loader(): | ||
data = Data() | ||
data.x = torch.randn(100, 128) | ||
data.node_id = torch.arange(100) | ||
data.edge_index = get_edge_index(100, 100, 500) | ||
data.edge_attr = torch.randn(500, 32) | ||
|
||
loader = RandomNodeLoader(data, num_parts=4, shuffle=True) | ||
assert len(loader) == 4 | ||
|
||
for batch in loader: | ||
assert len(batch) == 4 | ||
assert batch.node_id.min() >= 0 | ||
assert batch.node_id.max() < 100 | ||
assert batch.edge_index.size(1) == batch.edge_attr.size(0) | ||
assert torch.allclose(batch.x, data.x[batch.node_id]) | ||
batch.validate() | ||
|
||
|
||
def test_heterogeneous_random_node_loader(): | ||
data = HeteroData() | ||
data['paper'].x = torch.randn(100, 128) | ||
data['paper'].node_id = torch.arange(100) | ||
data['author'].x = torch.randn(200, 128) | ||
data['author'].node_id = torch.arange(200) | ||
data['paper', 'author'].edge_index = get_edge_index(100, 200, 500) | ||
data['paper', 'author'].edge_attr = torch.randn(500, 32) | ||
data['author', 'paper'].edge_index = get_edge_index(200, 100, 400) | ||
data['author', 'paper'].edge_attr = torch.randn(400, 32) | ||
data['paper', 'paper'].edge_index = get_edge_index(100, 100, 600) | ||
data['paper', 'paper'].edge_attr = torch.randn(600, 32) | ||
|
||
loader = RandomNodeLoader(data, num_parts=4, shuffle=True) | ||
assert len(loader) == 4 | ||
|
||
for batch in loader: | ||
assert len(batch) == 4 | ||
assert batch.node_types == data.node_types | ||
assert batch.edge_types == data.edge_types | ||
batch.validate() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import math | ||
from typing import Union | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from torch_geometric.data import Data, HeteroData | ||
from torch_geometric.data.hetero_data import to_homogeneous_edge_index | ||
|
||
|
||
class RandomNodeLoader(torch.utils.data.DataLoader): | ||
r"""A data loader that randomly samples nodes within a graph and returns | ||
their induced subgraph. | ||
.. note:: | ||
For an example of using | ||
:class:`~torch_geometric.loader.RandomNodeLoader`, see | ||
`examples/ogbn_proteins_deepgcn.py | ||
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ | ||
ogbn_proteins_deepgcn.py>`_. | ||
Args: | ||
data (torch_geometric.data.Data or torch_geometric.data.HeteroData): | ||
The :class:`~torch_geometric.data.Data` or | ||
:class:`~torch_geometric.data.HeteroData` graph object. | ||
num_parts (int): The number of partitions. | ||
**kwargs (optional): Additional arguments of | ||
:class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`. | ||
""" | ||
def __init__( | ||
self, | ||
data: Union[Data, HeteroData], | ||
num_parts: int, | ||
**kwargs, | ||
): | ||
self.data = data | ||
self.num_parts = num_parts | ||
|
||
if isinstance(data, HeteroData): | ||
edge_index, node_dict, edge_dict = to_homogeneous_edge_index(data) | ||
self.node_dict, self.edge_dict = node_dict, edge_dict | ||
else: | ||
edge_index = data.edge_index | ||
|
||
self.edge_index = edge_index | ||
self.num_nodes = data.num_nodes | ||
|
||
super().__init__( | ||
range(self.num_nodes), | ||
batch_size=math.ceil(self.num_nodes / num_parts), | ||
collate_fn=self.collate_fn, | ||
**kwargs, | ||
) | ||
|
||
def collate_fn(self, index): | ||
if not isinstance(index, Tensor): | ||
index = torch.tensor(index) | ||
|
||
if isinstance(self.data, Data): | ||
return self.data.subgraph(index) | ||
|
||
elif isinstance(self.data, HeteroData): | ||
node_dict = { | ||
key: index[(index >= start) & (index < end)] - start | ||
for key, (start, end) in self.node_dict.items() | ||
} | ||
return self.data.subgraph(node_dict) |
This file was deleted.
Oops, something went wrong.