Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for input_time in NeighborLoader #5763

Merged
merged 6 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763))
- Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717))
- Added `HeteroData` support for `transforms.Constant` ([#5700](https://github.com/pyg-team/pytorch_geometric/pull/5700))
- Added `np.memmap` support in `NeighborLoader` ([#5696](https://github.com/pyg-team/pytorch_geometric/pull/5696))
Expand Down
3 changes: 2 additions & 1 deletion test/loader/test_hgt_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ def test_hgt_loader():
assert set(batch.node_types) == {'paper', 'author'}
assert set(batch.edge_types) == set(data.edge_types)

assert len(batch['paper']) == 2
assert len(batch['paper']) == 3
assert batch['paper'].x.size() == (40, ) # 20 + 4 * 5
assert batch['paper'].input_nodes.numel() == batch_size
assert batch['paper'].batch_size == batch_size
assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100

Expand Down
14 changes: 8 additions & 6 deletions test/loader/test_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ def test_homogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
for batch in loader:
assert isinstance(batch, Data)

assert len(batch) == 5
assert len(batch) == 6
assert batch.x.size(0) <= 100
assert batch.x.min() >= 0 and batch.x.max() < 100
assert batch.input_links.numel() == 20
assert batch.edge_index.min() >= 0
assert batch.edge_index.max() < batch.num_nodes
assert batch.edge_attr.min() >= 0
Expand Down Expand Up @@ -110,7 +111,7 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio):

for batch in loader:
assert isinstance(batch, HeteroData)
assert len(batch) == 5
assert len(batch) == 6
if neg_sampling_ratio == 0.0:
# Assert only positive samples are present in the original graph:
assert batch['paper', 'author'].edge_label.sum() == 0
Expand All @@ -120,7 +121,6 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
assert len(edge_index | edge_label_index) == len(edge_index)

else:

assert batch['paper', 'author'].edge_label_index.size(1) == 40
assert torch.all(batch['paper', 'author'].edge_label[:20] == 1)
assert torch.all(batch['paper', 'author'].edge_label[20:] == 0)
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_temporal_heterogeneous_link_neighbor_loader():
data['paper', 'author'].edge_index = get_edge_index(100, 200, 1000)
data['author', 'paper'].edge_index = get_edge_index(200, 100, 1000)

with pytest.raises(ValueError, match=r"'edge_label_time' was not set.*"):
with pytest.raises(ValueError, match=r"'edge_label_time' is not set"):
loader = LinkNeighborLoader(
data,
num_neighbors=[-1] * 2,
Expand Down Expand Up @@ -312,7 +312,8 @@ def test_homogeneous_link_neighbor_loader_no_edges():

for batch in loader:
assert isinstance(batch, Data)
assert len(batch) == 3
assert len(batch) == 4
assert batch.input_links.numel() == 20
assert batch.num_nodes <= 40
assert batch.edge_label_index.size(1) == 20
assert batch.num_nodes == batch.edge_label_index.unique().numel()
Expand All @@ -328,8 +329,9 @@ def test_heterogeneous_link_neighbor_loader_no_edges():

for batch in loader:
assert isinstance(batch, HeteroData)
assert len(batch) == 3
assert len(batch) == 4
assert batch['paper'].num_nodes <= 40
assert batch['paper', 'paper'].input_links.numel() == 20
assert batch['paper', 'paper'].edge_label_index.size(1) == 20
assert batch['paper'].num_nodes == batch[
'paper', 'paper'].edge_label_index.unique().numel()
10 changes: 5 additions & 5 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,9 @@ def test_homogeneous_neighbor_loader(directed):

for batch in loader:
assert isinstance(batch, Data)

assert len(batch) == 4
assert len(batch) == 5
assert batch.x.size(0) <= 100
assert batch.batch_size == 20
assert batch.input_nodes.numel() == batch.batch_size == 20
assert batch.x.min() >= 0 and batch.x.max() < 100
assert batch.edge_index.min() >= 0
assert batch.edge_index.max() < batch.num_nodes
Expand Down Expand Up @@ -118,8 +117,9 @@ def test_heterogeneous_neighbor_loader(directed):
# Test node type selection:
assert set(batch.node_types) == {'paper', 'author'}

assert len(batch['paper']) == 2
assert len(batch['paper']) == 3
assert batch['paper'].x.size(0) <= 100
assert batch['paper'].input_nodes.numel() == batch_size
assert batch['paper'].batch_size == batch_size
assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100

Expand Down Expand Up @@ -498,7 +498,7 @@ def test_pyg_lib_heterogeneous_neighbor_loader():
'author__to__paper': [-1, -1],
}

sample = torch.ops.pyg.hetero_neighbor_sample_cpu
sample = torch.ops.pyg.hetero_neighbor_sample
out1 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict,
num_neighbors_dict, None, None, True, False, True, False,
"uniform", True)
Expand Down
12 changes: 3 additions & 9 deletions torch_geometric/data/lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,10 @@
from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader import LinkNeighborLoader, NeighborLoader
from torch_geometric.loader.dataloader import DataLoader
from torch_geometric.loader.link_neighbor_loader import (
LinkNeighborLoader,
get_edge_label_index,
)
from torch_geometric.loader.neighbor_loader import (
NeighborLoader,
NeighborSampler,
get_input_nodes,
)
from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes
from torch_geometric.sampler import NeighborSampler
from torch_geometric.typing import InputEdges, InputNodes

try:
Expand Down
6 changes: 4 additions & 2 deletions torch_geometric/loader/hgt_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,18 @@ def __init__(
**kwargs,
):
node_type, _ = get_input_nodes(data, input_nodes)
node_sampler = HGTSampler(

hgt_sampler = HGTSampler(
data,
num_samples=num_samples,
input_type=node_type,
is_sorted=is_sorted,
share_memory=kwargs.get('num_workers', 0) > 0,
)

super().__init__(
data=data,
node_sampler=node_sampler,
node_sampler=hgt_sampler,
input_nodes=input_nodes,
transform=transform,
filter_per_worker=filter_per_worker,
Expand Down
125 changes: 40 additions & 85 deletions torch_geometric/loader/link_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Iterator, Tuple, Union
from typing import Any, Callable, Iterator, List, Tuple, Union

import torch

Expand All @@ -7,6 +7,7 @@
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.utils import (
InputData,
filter_custom_store,
filter_data,
filter_hetero_data,
Expand Down Expand Up @@ -89,53 +90,57 @@ def __init__(
if 'collate_fn' in kwargs:
del kwargs['collate_fn']

self.data = data

# Initialize sampler with keyword arguments:
# NOTE sampler is an attribute of 'DataLoader', so we use link_sampler
# here:
self.link_sampler = link_sampler

# Store additional arguments:
self.edge_label = edge_label
self.edge_label_index = edge_label_index
self.edge_label_time = edge_label_time
self.transform = transform
self.filter_per_worker = filter_per_worker
self.neg_sampling_ratio = neg_sampling_ratio

# Get input type, or None for homogeneous graphs:
# Get edge type (or `None` for homogeneous graphs):
edge_type, edge_label_index = get_edge_label_index(
data, edge_label_index)
if edge_label is None:
edge_label = torch.zeros(edge_label_index.size(1),
device=edge_label_index.device)
self.input_type = edge_type

super().__init__(
Dataset(edge_label_index, edge_label, edge_label_time),
collate_fn=self.collate_fn,
**kwargs,
self.data = data
self.edge_type = edge_type
self.link_sampler = link_sampler
self.input_data = InputData(edge_label_index[0], edge_label_index[1],
edge_label, edge_label_time)
self.neg_sampling_ratio = neg_sampling_ratio
self.transform = transform
self.filter_per_worker = filter_per_worker

iterator = range(edge_label_index.size(1))
super().__init__(iterator, collate_fn=self.collate_fn, **kwargs)

def collate_fn(self, index: List[int]) -> Any:
r"""Samples a subgraph from a batch of input nodes."""
input_data: EdgeSamplerInput = self.input_data[index]
out = self.link_sampler.sample_from_edges(
input_data,
negative_sampling_ratio=self.neg_sampling_ratio,
)

if self.filter_per_worker: # Execute `filter_fn` in the worker process
out = self.filter_fn(out)

return out

def filter_fn(
self,
out: Union[SamplerOutput, HeteroSamplerOutput],
) -> Union[Data, HeteroData]:
r"""Joins the sampled nodes with their corresponding features,
returning the resulting (Data or HeteroData) object to be used
downstream."""
returning the resulting :class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData` object to be used downstream.
"""
if isinstance(out, SamplerOutput):
edge_label_index, edge_label, edge_label_time = out.metadata
data = filter_data(self.data, out.node, out.row, out.col, out.edge,
self.link_sampler.edge_permutation)

data.batch = out.batch
data.edge_label_index = edge_label_index
data.edge_label = edge_label
data.edge_label_time = edge_label_time
data.input_links = out.metadata[0]
data.edge_label_index = out.metadata[1]
data.edge_label = out.metadata[2]
data.edge_label_time = out.metadata[3]

elif isinstance(out, HeteroSamplerOutput):
edge_label_index, edge_label, edge_label_time = out.metadata
if isinstance(self.data, HeteroData):
data = filter_hetero_data(self.data, out.node, out.row,
out.col, out.edge,
Expand All @@ -144,75 +149,25 @@ def filter_fn(
data = filter_custom_store(*self.data, out.node, out.row,
out.col, out.edge)

edge_type = self.input_type
for key, batch in (out.batch or {}).items():
data[key].batch = batch
data[edge_type].edge_label_index = edge_label_index
data[edge_type].edge_label = edge_label
if edge_label_time is not None:
data[edge_type].edge_label_time = edge_label_time
data[self.edge_type].input_links = out.metadata[0]
data[self.edge_type].edge_label_index = out.metadata[1]
data[self.edge_type].edge_label = out.metadata[2]
data[self.edge_type].edge_label_time = out.metadata[3]

else:
raise TypeError(f"'{self.__class__.__name__}'' found invalid "
f"type: '{type(out)}'")

return data if self.transform is None else self.transform(data)

def collate_fn(self, index: EdgeSamplerInput) -> Any:
r"""Samples a subgraph from a batch of input nodes."""
out = self.link_sampler.sample_from_edges(
index,
negative_sampling_ratio=self.neg_sampling_ratio,
)
if self.filter_per_worker:
# We execute `filter_fn` in the worker process.
out = self.filter_fn(out)
return out

def _get_iterator(self) -> Iterator:
if self.filter_per_worker:
return super()._get_iterator()
# We execute `filter_fn` in the main process.

# Execute `filter_fn` in the main process:
return DataLoaderIterator(super()._get_iterator(), self.filter_fn)

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'


###############################################################################


class Dataset(torch.utils.data.Dataset):
def __init__(
self,
edge_label_index: torch.Tensor,
edge_label: torch.Tensor,
edge_label_time: OptTensor = None,
):
# NOTE see documentation of LinkLoader for details on these three
# input parameters:
self.edge_label_index = edge_label_index
self.edge_label = edge_label
self.edge_label_time = edge_label_time

def __getitem__(
self,
idx: int,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if self.edge_label_time is None:
return (
self.edge_label_index[0, idx],
self.edge_label_index[1, idx],
self.edge_label[idx],
)
else:
return (
self.edge_label_index[0, idx],
self.edge_label_index[1, idx],
self.edge_label[idx],
self.edge_label_time[idx],
)

def __len__(self) -> int:
return self.edge_label_index.size(1)
19 changes: 7 additions & 12 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,16 @@ def __init__(
neighbor_sampler: Optional[NeighborSampler] = None,
**kwargs,
):
# Get input type:
# TODO(manan): this computation is required twice, once here and once
# in LinkLoader:
# TODO(manan): Avoid duplicated computation (here and in NodeLoader):
edge_type, _ = get_edge_label_index(data, edge_label_index)

has_time_attr = time_attr is not None
has_edge_label_time = edge_label_time is not None
if has_edge_label_time != has_time_attr:
if (edge_label_time is not None) != (time_attr is not None):
raise ValueError(
f"Received conflicting 'time_attr' and 'edge_label_time' "
f"arguments: 'time_attr' was "
f"{'set' if has_time_attr else 'not set'} and "
f"'edge_label_time' was "
f"{'set' if has_edge_label_time else 'not set'}. Please "
f"resolve these conflicting arguments.")
f"Received conflicting 'edge_label_time' and 'time_attr' "
f"arguments: 'edge_label_time' is "
f"{'set' if edge_label_time is not None else 'not set'} "
f"while 'input_time' is "
f"{'set' if time_attr is not None else 'not set'}.")

if neighbor_sampler is None:
neighbor_sampler = NeighborSampler(
Expand Down
Loading