Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LinkNeighborLoader with sorted timestamps #5602

Merged
merged 4 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
- Fixed a bug when applying several scalers with `PNAConv` ([#5514](https://github.com/pyg-team/pytorch_geometric/issues/5514))
- Allow `.` in `ParameterDict` key names ([#5494](https://github.com/pyg-team/pytorch_geometric/pull/5494))
- Renamed `drop_unconnected_nodes` to `drop_unconnected_node_types` and `drop_orig_edges` to `drop_orig_edge_types` in `AddMetapaths` ([#5490](https://github.com/pyg-team/pytorch_geometric/pull/5490))
Expand Down
24 changes: 16 additions & 8 deletions test/loader/test_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,18 +194,26 @@ def test_temporal_heterogeneous_link_neighbor_loader():
data['author', 'paper'].edge_index = get_edge_index(200, 100, 1000)

with pytest.raises(ValueError, match=r"'edge_label_time' was not set.*"):
loader = LinkNeighborLoader(data, num_neighbors=[-1] * 2,
edge_label_index=('paper', 'paper'),
batch_size=32, time_attr='time')
loader = LinkNeighborLoader(
data,
num_neighbors=[-1] * 2,
edge_label_index=('paper', 'paper'),
batch_size=32,
time_attr='time',
)

# With edge_time:
edge_time = torch.arange(data['paper', 'paper'].edge_index.size(1))
paper_time_original = data['paper'].time.clone()
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
loader = LinkNeighborLoader(data, num_neighbors=[-1] * 2,
edge_label_index=('paper', 'paper'),
edge_label_time=edge_time, batch_size=32,
time_attr='time', neg_sampling_ratio=0.5,
num_workers=2)
loader = LinkNeighborLoader(
data,
num_neighbors=[-1] * 2,
edge_label_index=('paper', 'paper'),
edge_label_time=edge_time,
batch_size=32,
time_attr='time',
neg_sampling_ratio=0.5,
)
for batch in loader:
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
author_max = batch['author'].time.max()
edge_max = batch['paper', 'paper'].edge_label_time.max()
Expand Down
5 changes: 3 additions & 2 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def test_pyg_lib_homogeneous_neighbor_loader():
seed = torch.arange(10)

sample = torch.ops.pyg.neighbor_sample
out1 = sample(colptr, row, seed, [-1, -1], time=None, csc=True)
out1 = sample(colptr, row, seed, [-1, -1], None, None, True)
sample = torch.ops.torch_sparse.neighbor_sample
out2 = sample(colptr, row, seed, [-1, -1], False, True)

Expand Down Expand Up @@ -494,7 +494,8 @@ def test_pyg_lib_heterogeneous_neighbor_loader():

sample = torch.ops.pyg.hetero_neighbor_sample_cpu
out1 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict,
num_neighbors_dict, None, True, False, True, False, True)
num_neighbors_dict, None, None, True, False, True, False,
"uniform", True)
sample = torch.ops.torch_sparse.hetero_neighbor_sample
out2 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict,
num_neighbors_dict, 2, False, True)
Expand Down
80 changes: 40 additions & 40 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch_scatter import scatter_min

from torch_geometric.data import Data, HeteroData, remote_backend_utils
from torch_geometric.data.feature_store import FeatureStore
Expand All @@ -15,7 +16,6 @@
from torch_geometric.sampler.utils import (
add_negative_samples,
remap_keys,
set_node_time_dict,
to_csc,
to_hetero_csc,
)
Expand Down Expand Up @@ -234,7 +234,8 @@ def _sample(
self.row_dict,
seed, # seed_dict
self.num_neighbors,
kwargs.get('node_time_dict', self.node_time_dict),
self.node_time_dict,
kwargs.get('seed_time_dict', None),
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
True, # csc
self.replace,
self.directed,
Expand All @@ -249,33 +250,20 @@ def _sample(
node = {k: v[1] for k, v in node.items()}

else:
if self.node_time_dict is None:
out = torch.ops.torch_sparse.hetero_neighbor_sample(
self.node_types,
self.edge_types,
self.colptr_dict,
self.row_dict,
seed, # seed_dict
self.num_neighbors,
self.num_hops,
self.replace,
self.directed,
)
else:
assert self.temporal_strategy == 'uniform'
fn = torch.ops.torch_sparse.hetero_temporal_neighbor_sample
out = fn(
self.node_types,
self.edge_types,
self.colptr_dict,
self.row_dict,
seed, # seed_dict
self.num_neighbors,
kwargs.get('node_time_dict', self.node_time_dict),
self.num_hops,
self.replace,
self.directed,
)
if self.node_time_dict is not None:
raise ValueError("'time_attr' not supported for "
"neighbor sampling via 'torch-sparse'")
out = torch.ops.torch_sparse.hetero_neighbor_sample(
self.node_types,
self.edge_types,
self.colptr_dict,
self.row_dict,
seed, # seed_dict
self.num_neighbors,
self.num_hops,
self.replace,
self.directed,
)
node, row, col, edge, batch = out + (None, )

return HeteroSamplerOutput(
Expand All @@ -296,7 +284,8 @@ def _sample(
self.row,
seed, # seed
self.num_neighbors,
kwargs.get('node_time', self.node_time),
self.node_time,
kwargs.get('seed_time', None),
True, # csc
self.replace,
self.directed,
Expand Down Expand Up @@ -377,7 +366,7 @@ def sample_from_edges(
self.num_dst_nodes, negative_sampling_ratio)
edge_label_index, edge_label, edge_label_time = out

orig_edge_label_index = edge_label_index
seed_time_dict = None
if (self.data_cls == 'custom'
or issubclass(self.data_cls, HeteroData)):
if self.input_type[0] != self.input_type[-1]:
Expand All @@ -390,28 +379,39 @@ def sample_from_edges(
self.input_type[0]: query_src,
self.input_type[-1]: query_dst,
}
if edge_label_time is not None:
seed_time_dict = {
self.input_type[0]:
scatter_min(edge_label_time, reverse_src)[0],
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
self.input_type[-1]:
scatter_min(edge_label_time, reverse_dst)[0],
}

else: # Merge both source and destination node indices:
query_nodes = edge_label_index.view(-1)
query_nodes, reverse = query_nodes.unique(return_inverse=True)
edge_label_index = reverse.view(2, -1)
query_node_dict = {self.input_type[0]: query_nodes}
if edge_label_time is not None:
tmp = torch.cat([edge_label_time, edge_label_time])
seed_time_dict = {
self.input_type[0]: scatter_min(tmp, reverse)[0]
}

output = self._sample(
seed=query_node_dict,
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
seed_time_dict=seed_time_dict,
)

node_time_dict = self.node_time_dict
if edge_label_time is not None:
node_time_dict = set_node_time_dict(
node_time_dict, self.input_type, orig_edge_label_index,
edge_label_time, self.num_src_nodes, self.num_dst_nodes)

output = self._sample(seed=query_node_dict,
node_time_dict=node_time_dict)
output.metadata = (edge_label_index, edge_label, edge_label_time)

elif issubclass(self.data_cls, Data):
assert self.node_time is None # TODO
query_nodes = edge_label_index.view(-1)
query_nodes, reverse = query_nodes.unique(return_inverse=True)
edge_label_index = reverse.view(2, -1)

output = self._sample(seed=query_nodes)
output = self._sample(seed=query_nodes, seed_time=None)
output.metadata = (edge_label_index, edge_label)

else:
Expand Down
34 changes: 3 additions & 31 deletions torch_geometric/sampler/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import copy
from typing import Any, Dict, Optional, Set, Tuple, Union

import torch
from torch import Tensor
from torch_scatter import scatter_min

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.storage import EdgeStorage
from torch_geometric.typing import EdgeType, NodeType, OptTensor
from torch_geometric.typing import NodeType, OptTensor

# Edge Layout Conversion ######################################################

Expand All @@ -24,8 +22,8 @@ def sort_csc(
# Multiplying by raw `datetime[64ns]` values may cause overflows.
# As such, we normalize time into range [0, 1) before sorting:
src_node_time = src_node_time.to(torch.double, copy=True)
min_time, max_time = src_node_time.min(), src_node_time.max() + 1
src_node_time.sub_(min_time).div_(max_time)
src_node_time.sub_(src_node_time.min())
src_node_time.div_(src_node_time.max() + 1)

perm = src_node_time[row].add_(col.to(torch.double)).argsort()
return row[perm], col[perm], perm
Expand Down Expand Up @@ -151,32 +149,6 @@ def add_negative_samples(
return edge_label_index, edge_label, edge_label_time


def set_node_time_dict(
node_time_dict,
input_type: EdgeType,
edge_label_index,
edge_label_time,
num_src_nodes: int,
num_dst_nodes: int,
):
"""For edges in a batch replace `src` and `dst` node times by the min
across all edge times."""
def update_time_(node_time_dict, index, node_type, num_nodes):
node_time_dict[node_type] = node_time_dict[node_type].clone()
node_time, _ = scatter_min(edge_label_time, index, dim=0,
dim_size=num_nodes)
# NOTE We assume that node_time is always less than edge_time.
index_unique = index.unique()
node_time_dict[node_type][index_unique] = node_time[index_unique]

node_time_dict = copy.copy(node_time_dict)
update_time_(node_time_dict, edge_label_index[0], input_type[0],
num_src_nodes)
update_time_(node_time_dict, edge_label_index[1], input_type[-1],
num_dst_nodes)
return node_time_dict


###############################################################################


Expand Down