Skip to content

Commit

Permalink
NeighborLoader: support temporal sampling with `(FeatureStore, Grap…
Browse files Browse the repository at this point in the history
…hStore)` (#4929)
  • Loading branch information
mananshah99 authored Jul 7, 2022
1 parent fdb1ab0 commit db5e6d9
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
Expand Down
45 changes: 45 additions & 0 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch_sparse import SparseTensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import TensorAttr
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GraphConv, to_hetero
from torch_geometric.testing import withRegisteredOp
Expand Down Expand Up @@ -359,3 +360,47 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
'paper', 'to', 'author'].edge_index.size())
assert (batch1['author', 'to', 'paper'].edge_index.size() == batch1[
'author', 'to', 'paper'].edge_index.size())


@withRegisteredOp('torch_sparse.hetero_temporal_neighbor_sample')
@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData])
@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData])
def test_temporal_custom_neighbor_loader_on_cora(get_dataset, FeatureStore,
GraphStore):
# Initialize dataset (once):
dataset = get_dataset(name='Cora')
data = dataset[0]

# Initialize feature store, graph store, and reference:
feature_store = FeatureStore()
graph_store = GraphStore()
hetero_data = HeteroData()

feature_store.put_tensor(data.x, group_name='paper', attr_name='x',
index=None)
hetero_data['paper'].x = data.x

feature_store.put_tensor(torch.arange(data.num_nodes), group_name='paper',
attr_name='time', index=None)
hetero_data['paper'].time = torch.arange(data.num_nodes)

num_nodes = data.x.size(dim=0)
graph_store.put_edge_index(edge_index=data.edge_index,
edge_type=('paper', 'to', 'paper'),
layout='coo', size=(num_nodes, num_nodes))
hetero_data['paper', 'to', 'paper'].edge_index = data.edge_index

loader1 = NeighborLoader(hetero_data, num_neighbors=[-1, -1],
input_nodes='paper', time_attr='time',
batch_size=128)

loader2 = NeighborLoader(
(feature_store, graph_store),
num_neighbors=[-1, -1],
input_nodes=TensorAttr(group_name='paper', attr_name='x'),
time_attr='time',
batch_size=128,
)

for batch1, batch2 in zip(loader1, loader2):
assert torch.equal(batch1['paper'].time, batch2['paper'].time)
13 changes: 13 additions & 0 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,10 +874,23 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]:
in `Data` and their layouts"""
if not hasattr(self, '_edge_attrs'):
return []
added_attrs = set()

# Check edges added via _put_edge_index:
edge_attrs = self._edge_attrs.values()
for attr in edge_attrs:
attr.size = (self.num_nodes, self.num_nodes)
added_attrs.add(attr.layout)

# Check edges added through regular interface:
# TODO deprecate this and store edge attributes for all edges in
# EdgeStorage
for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
if attr_name in self and layout not in added_attrs:
edge_attrs.append(
EdgeAttr(edge_type=None, layout=layout,
size=(self.num_nodes, self.num_nodes)))

return edge_attrs


Expand Down
21 changes: 19 additions & 2 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
out = self._node_store_dict.get(attr.group_name, None)
if out:
# Group name exists, handle index or create new attribute name:
val = getattr(out, attr.attr_name)
val = getattr(out, attr.attr_name, None)
if val is not None:
val[attr.index] = tensor
else:
Expand Down Expand Up @@ -787,13 +787,30 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]:
r"""Returns a list of `EdgeAttr` objects corresponding to the edge
indices stored in `HeteroData` and their layouts."""
out = []
for edge_type, edge_store in self.edge_items():
added_attrs = set()

# Check edges added via _put_edge_index:
for edge_type, _ in self.edge_items():
if not hasattr(self[edge_type], '_edge_attrs'):
continue
edge_attrs = self[edge_type]._edge_attrs.values()
for attr in edge_attrs:
attr.size = self[edge_type].size()
added_attrs.add((attr.edge_type, attr.layout))
out.extend(edge_attrs)

# Check edges added through regular interface:
# TODO deprecate this and store edge attributes for all edges in
# EdgeStorage
for edge_type, edge_store in self.edge_items():
for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
# Don't double count:
if attr_name in edge_store and ((edge_type, layout)
not in added_attrs):
out.append(
EdgeAttr(edge_type=edge_type, layout=layout,
size=self[edge_type].size()))

return out


Expand Down
31 changes: 24 additions & 7 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,21 @@ def __init__(
# TODO support `collect` on `FeatureStore`
self.node_time_dict = None
if time_attr is not None:
raise ValueError(
f"'time_attr' attribute not yet supported for "
f"'{data[0].__class__.__name__}' object")
# We need to obtain all features with 'attr_name=time_attr'
# from the feature store and store them in node_time_dict. To
# do so, we make an explicit feature store GET call here with
# the relevant 'TensorAttr's
time_attrs = [
attr for attr in feature_store.get_all_tensor_attrs()
if attr.attr_name == time_attr
]
for attr in time_attrs:
attr.index = None
time_tensors = feature_store.multi_get_tensor(time_attrs)
self.node_time_dict = {
time_attr.group_name: time_tensor
for time_attr, time_tensor in zip(time_attrs, time_tensors)
}

# Obtain all node and edge metadata:
node_attrs = feature_store.get_all_tensor_attrs()
Expand Down Expand Up @@ -475,18 +487,23 @@ def to_index(tensor):
if isinstance(input_nodes, Tensor):
return None, to_index(input_nodes)

# Can't infer number of nodes from a group_name; need an attr_name
if isinstance(input_nodes, str):
num_nodes = feature_store.get_tensor_size(input_nodes)[0]
return input_nodes, range(num_nodes)
raise NotImplementedError(
f"Cannot infer the number of nodes from a single string "
f"(got '{input_nodes}'). Please pass a more explicit "
f"representation. ")

if isinstance(input_nodes, (list, tuple)):
assert len(input_nodes) == 2
assert isinstance(input_nodes[0], str)

node_type, input_nodes = input_nodes
if input_nodes is None:
num_nodes = feature_store.get_tensor_size(input_nodes)[0]
return input_nodes[0], range(num_nodes)
raise NotImplementedError(
f"Cannot infer the number of nodes from a node type alone "
f"(got '{input_nodes}'). Please pass a more explicit "
f"representation. ")
return node_type, to_index(input_nodes)

assert isinstance(input_nodes, TensorAttr)
Expand Down

0 comments on commit db5e6d9

Please sign in to comment.