Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NeighborLoader: support temporal sampling with (FeatureStore, GraphStore) #4929

Merged
merged 5 commits into from
Jul 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
Expand Down
45 changes: 45 additions & 0 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch_sparse import SparseTensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import TensorAttr
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GraphConv, to_hetero
from torch_geometric.testing import withRegisteredOp
Expand Down Expand Up @@ -359,3 +360,47 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
'paper', 'to', 'author'].edge_index.size())
assert (batch1['author', 'to', 'paper'].edge_index.size() == batch1[
'author', 'to', 'paper'].edge_index.size())


@withRegisteredOp('torch_sparse.hetero_temporal_neighbor_sample')
@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData])
@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData])
def test_temporal_custom_neighbor_loader_on_cora(get_dataset, FeatureStore,
GraphStore):
# Initialize dataset (once):
dataset = get_dataset(name='Cora')
data = dataset[0]

# Initialize feature store, graph store, and reference:
feature_store = FeatureStore()
graph_store = GraphStore()
hetero_data = HeteroData()

feature_store.put_tensor(data.x, group_name='paper', attr_name='x',
index=None)
hetero_data['paper'].x = data.x

feature_store.put_tensor(torch.arange(data.num_nodes), group_name='paper',
attr_name='time', index=None)
hetero_data['paper'].time = torch.arange(data.num_nodes)

num_nodes = data.x.size(dim=0)
graph_store.put_edge_index(edge_index=data.edge_index,
edge_type=('paper', 'to', 'paper'),
layout='coo', size=(num_nodes, num_nodes))
hetero_data['paper', 'to', 'paper'].edge_index = data.edge_index

loader1 = NeighborLoader(hetero_data, num_neighbors=[-1, -1],
input_nodes='paper', time_attr='time',
batch_size=128)

loader2 = NeighborLoader(
(feature_store, graph_store),
num_neighbors=[-1, -1],
input_nodes=TensorAttr(group_name='paper', attr_name='x'),
time_attr='time',
batch_size=128,
)

for batch1, batch2 in zip(loader1, loader2):
assert torch.equal(batch1['paper'].time, batch2['paper'].time)
13 changes: 13 additions & 0 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,10 +874,23 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]:
in `Data` and their layouts"""
if not hasattr(self, '_edge_attrs'):
return []
added_attrs = set()

# Check edges added via _put_edge_index:
edge_attrs = self._edge_attrs.values()
for attr in edge_attrs:
attr.size = (self.num_nodes, self.num_nodes)
added_attrs.add(attr.layout)

# Check edges added through regular interface:
# TODO deprecate this and store edge attributes for all edges in
# EdgeStorage
for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
if attr_name in self and layout not in added_attrs:
edge_attrs.append(
EdgeAttr(edge_type=None, layout=layout,
size=(self.num_nodes, self.num_nodes)))

return edge_attrs


Expand Down
21 changes: 19 additions & 2 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
out = self._node_store_dict.get(attr.group_name, None)
if out:
# Group name exists, handle index or create new attribute name:
val = getattr(out, attr.attr_name)
val = getattr(out, attr.attr_name, None)
if val is not None:
val[attr.index] = tensor
else:
Expand Down Expand Up @@ -787,13 +787,30 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]:
r"""Returns a list of `EdgeAttr` objects corresponding to the edge
indices stored in `HeteroData` and their layouts."""
out = []
for edge_type, edge_store in self.edge_items():
added_attrs = set()

# Check edges added via _put_edge_index:
for edge_type, _ in self.edge_items():
if not hasattr(self[edge_type], '_edge_attrs'):
continue
edge_attrs = self[edge_type]._edge_attrs.values()
for attr in edge_attrs:
attr.size = self[edge_type].size()
added_attrs.add((attr.edge_type, attr.layout))
out.extend(edge_attrs)

# Check edges added through regular interface:
# TODO deprecate this and store edge attributes for all edges in
# EdgeStorage
for edge_type, edge_store in self.edge_items():
for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
# Don't double count:
if attr_name in edge_store and ((edge_type, layout)
not in added_attrs):
out.append(
EdgeAttr(edge_type=edge_type, layout=layout,
size=self[edge_type].size()))

return out


Expand Down
31 changes: 24 additions & 7 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,21 @@ def __init__(
# TODO support `collect` on `FeatureStore`
self.node_time_dict = None
if time_attr is not None:
raise ValueError(
f"'time_attr' attribute not yet supported for "
f"'{data[0].__class__.__name__}' object")
# We need to obtain all features with 'attr_name=time_attr'
# from the feature store and store them in node_time_dict. To
# do so, we make an explicit feature store GET call here with
# the relevant 'TensorAttr's
time_attrs = [
attr for attr in feature_store.get_all_tensor_attrs()
if attr.attr_name == time_attr
]
for attr in time_attrs:
attr.index = None
time_tensors = feature_store.multi_get_tensor(time_attrs)
self.node_time_dict = {
mananshah99 marked this conversation as resolved.
Show resolved Hide resolved
time_attr.group_name: time_tensor
for time_attr, time_tensor in zip(time_attrs, time_tensors)
}

# Obtain all node and edge metadata:
node_attrs = feature_store.get_all_tensor_attrs()
Expand Down Expand Up @@ -475,18 +487,23 @@ def to_index(tensor):
if isinstance(input_nodes, Tensor):
return None, to_index(input_nodes)

# Can't infer number of nodes from a group_name; need an attr_name
if isinstance(input_nodes, str):
num_nodes = feature_store.get_tensor_size(input_nodes)[0]
return input_nodes, range(num_nodes)
raise NotImplementedError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we raise here instead of fix it? Shouldn't feature_store.get_tensor_size(TensorAttr(group_name=input_nodes))[0] fix this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That get tensor size call wouldn't work as intended since get tensor rise needs both a group name and are name, hence the raise

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So what we could do is to select the first TensorAttr that has group_name==input_nodes, right? But yeah, not ideal.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, not a huge fan of that solution; will leave for a later PR if we have to go that route.

f"Cannot infer the number of nodes from a single string "
f"(got '{input_nodes}'). Please pass a more explicit "
f"representation. ")

if isinstance(input_nodes, (list, tuple)):
assert len(input_nodes) == 2
assert isinstance(input_nodes[0], str)

node_type, input_nodes = input_nodes
if input_nodes is None:
num_nodes = feature_store.get_tensor_size(input_nodes)[0]
return input_nodes[0], range(num_nodes)
raise NotImplementedError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feature_store.get_tensor_size(TensorAttr(group_name=None))[0]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same reason as above.

f"Cannot infer the number of nodes from a node type alone "
f"(got '{input_nodes}'). Please pass a more explicit "
f"representation. ")
return node_type, to_index(input_nodes)

assert isinstance(input_nodes, TensorAttr)
Expand Down