diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a8a4dbc6d33..e03d3bd0a2be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908)) - Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873)) - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815)) -- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922)) +- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922)) - Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847)) - Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850)) - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838)) diff --git a/test/loader/test_neighbor_loader.py b/test/loader/test_neighbor_loader.py index 445cfdf840ee..3afd5223f80c 100644 --- a/test/loader/test_neighbor_loader.py +++ b/test/loader/test_neighbor_loader.py @@ -4,6 +4,7 @@ from torch_sparse import SparseTensor from torch_geometric.data import Data, HeteroData +from torch_geometric.data.feature_store import TensorAttr from torch_geometric.loader import NeighborLoader from torch_geometric.nn import GraphConv, to_hetero from torch_geometric.testing import withRegisteredOp @@ -359,3 +360,47 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore): 'paper', 'to', 'author'].edge_index.size()) assert (batch1['author', 'to', 'paper'].edge_index.size() == batch1[ 'author', 'to', 'paper'].edge_index.size()) + + +@withRegisteredOp('torch_sparse.hetero_temporal_neighbor_sample') +@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData]) +@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData]) +def test_temporal_custom_neighbor_loader_on_cora(get_dataset, FeatureStore, + GraphStore): + # Initialize dataset (once): + dataset = get_dataset(name='Cora') + data = dataset[0] + + # Initialize feature store, graph store, and reference: + feature_store = FeatureStore() + graph_store = GraphStore() + hetero_data = HeteroData() + + feature_store.put_tensor(data.x, group_name='paper', attr_name='x', + index=None) + hetero_data['paper'].x = data.x + + feature_store.put_tensor(torch.arange(data.num_nodes), group_name='paper', + attr_name='time', index=None) + hetero_data['paper'].time = torch.arange(data.num_nodes) + + num_nodes = data.x.size(dim=0) + graph_store.put_edge_index(edge_index=data.edge_index, + edge_type=('paper', 'to', 'paper'), + layout='coo', size=(num_nodes, num_nodes)) + hetero_data['paper', 'to', 'paper'].edge_index = data.edge_index + + loader1 = NeighborLoader(hetero_data, num_neighbors=[-1, -1], + input_nodes='paper', time_attr='time', + batch_size=128) + + loader2 = NeighborLoader( + (feature_store, graph_store), + num_neighbors=[-1, -1], + input_nodes=TensorAttr(group_name='paper', attr_name='x'), + time_attr='time', + batch_size=128, + ) + + for batch1, batch2 in zip(loader1, loader2): + assert torch.equal(batch1['paper'].time, batch2['paper'].time) diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index a2ad1eef9c1f..2c3200e4ab25 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -874,10 +874,23 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]: in `Data` and their layouts""" if not hasattr(self, '_edge_attrs'): return [] + added_attrs = set() + # Check edges added via _put_edge_index: edge_attrs = self._edge_attrs.values() for attr in edge_attrs: attr.size = (self.num_nodes, self.num_nodes) + added_attrs.add(attr.layout) + + # Check edges added through regular interface: + # TODO deprecate this and store edge attributes for all edges in + # EdgeStorage + for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items(): + if attr_name in self and layout not in added_attrs: + edge_attrs.append( + EdgeAttr(edge_type=None, layout=layout, + size=(self.num_nodes, self.num_nodes))) + return edge_attrs diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index 1cdf571ae088..ebb79dbc2d18 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -695,7 +695,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: out = self._node_store_dict.get(attr.group_name, None) if out: # Group name exists, handle index or create new attribute name: - val = getattr(out, attr.attr_name) + val = getattr(out, attr.attr_name, None) if val is not None: val[attr.index] = tensor else: @@ -787,13 +787,30 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]: r"""Returns a list of `EdgeAttr` objects corresponding to the edge indices stored in `HeteroData` and their layouts.""" out = [] - for edge_type, edge_store in self.edge_items(): + added_attrs = set() + + # Check edges added via _put_edge_index: + for edge_type, _ in self.edge_items(): if not hasattr(self[edge_type], '_edge_attrs'): continue edge_attrs = self[edge_type]._edge_attrs.values() for attr in edge_attrs: attr.size = self[edge_type].size() + added_attrs.add((attr.edge_type, attr.layout)) out.extend(edge_attrs) + + # Check edges added through regular interface: + # TODO deprecate this and store edge attributes for all edges in + # EdgeStorage + for edge_type, edge_store in self.edge_items(): + for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items(): + # Don't double count: + if attr_name in edge_store and ((edge_type, layout) + not in added_attrs): + out.append( + EdgeAttr(edge_type=edge_type, layout=layout, + size=self[edge_type].size())) + return out diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index ff3c0e7b9cfa..cdf89f97b473 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -95,9 +95,21 @@ def __init__( # TODO support `collect` on `FeatureStore` self.node_time_dict = None if time_attr is not None: - raise ValueError( - f"'time_attr' attribute not yet supported for " - f"'{data[0].__class__.__name__}' object") + # We need to obtain all features with 'attr_name=time_attr' + # from the feature store and store them in node_time_dict. To + # do so, we make an explicit feature store GET call here with + # the relevant 'TensorAttr's + time_attrs = [ + attr for attr in feature_store.get_all_tensor_attrs() + if attr.attr_name == time_attr + ] + for attr in time_attrs: + attr.index = None + time_tensors = feature_store.multi_get_tensor(time_attrs) + self.node_time_dict = { + time_attr.group_name: time_tensor + for time_attr, time_tensor in zip(time_attrs, time_tensors) + } # Obtain all node and edge metadata: node_attrs = feature_store.get_all_tensor_attrs() @@ -475,9 +487,12 @@ def to_index(tensor): if isinstance(input_nodes, Tensor): return None, to_index(input_nodes) + # Can't infer number of nodes from a group_name; need an attr_name if isinstance(input_nodes, str): - num_nodes = feature_store.get_tensor_size(input_nodes)[0] - return input_nodes, range(num_nodes) + raise NotImplementedError( + f"Cannot infer the number of nodes from a single string " + f"(got '{input_nodes}'). Please pass a more explicit " + f"representation. ") if isinstance(input_nodes, (list, tuple)): assert len(input_nodes) == 2 @@ -485,8 +500,10 @@ def to_index(tensor): node_type, input_nodes = input_nodes if input_nodes is None: - num_nodes = feature_store.get_tensor_size(input_nodes)[0] - return input_nodes[0], range(num_nodes) + raise NotImplementedError( + f"Cannot infer the number of nodes from a node type alone " + f"(got '{input_nodes}'). Please pass a more explicit " + f"representation. ") return node_type, to_index(input_nodes) assert isinstance(input_nodes, TensorAttr)