From 58df97840d46ff4da3b8dda0d978554fc2dae4b8 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Sat, 30 Jul 2022 14:10:44 -0700 Subject: [PATCH] Fix: properly filter custom feature/graph stores (#5088) * init * changelog * typo * typo * update Co-authored-by: rusty1s --- CHANGELOG.md | 2 +- .../loader/link_neighbor_loader.py | 8 ++++---- torch_geometric/loader/neighbor_loader.py | 8 ++++---- torch_geometric/loader/utils.py | 19 +++++++++---------- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62f212237e6f..2ba47aeda5d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908)) - Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873)) - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815), [#4862](https://github.com/pyg-team/pytorch_geometric/pull/4862/files)) -- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037)) +- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037), [#5088](https://github.com/pyg-team/pytorch_geometric/pull/5088)) - Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847)) - Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850)) - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838)) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index 95631a771675..25400c7b593f 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -9,8 +9,8 @@ from torch_geometric.loader.base import DataLoaderIterator from torch_geometric.loader.neighbor_loader import NeighborSampler from torch_geometric.loader.utils import ( + filter_custom_store, filter_data, - filter_feature_store, filter_hetero_data, ) from torch_geometric.typing import InputEdges, NumNeighbors, OptTensor @@ -339,9 +339,9 @@ def filter_fn(self, out: Any) -> Union[Data, HeteroData]: else: (node_dict, row_dict, col_dict, edge_dict, edge_label_index, edge_label) = out - feature_store, _ = self.data - data = filter_feature_store(feature_store, node_dict, row_dict, - col_dict, edge_dict) + feature_store, graph_store = self.data + data = filter_custom_store(feature_store, graph_store, node_dict, + row_dict, col_dict, edge_dict) edge_type = self.neighbor_sampler.input_type data[edge_type].edge_label_index = edge_label_index if edge_label is not None: diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index cc9248e6dd6a..f6cd694e1c76 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -10,8 +10,8 @@ from torch_geometric.loader.base import DataLoaderIterator from torch_geometric.loader.utils import ( edge_type_to_str, + filter_custom_store, filter_data, - filter_feature_store, filter_hetero_data, to_csc, to_hetero_csc, @@ -415,9 +415,9 @@ def filter_fn(self, out: Any) -> Union[Data, HeteroData]: else: # Tuple[FeatureStore, GraphStore] # TODO support for feature stores with no edge types node_dict, row_dict, col_dict, edge_dict, batch_size = out - feature_store, _ = self.data - data = filter_feature_store(feature_store, node_dict, row_dict, - col_dict, edge_dict) + feature_store, graph_store = self.data + data = filter_custom_store(feature_store, graph_store, node_dict, + row_dict, col_dict, edge_dict) data[self.neighbor_sampler.input_type].batch_size = batch_size return data if self.transform is None else self.transform(data) diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py index 0658d7f94eab..2b34057e86c7 100644 --- a/torch_geometric/loader/utils.py +++ b/torch_geometric/loader/utils.py @@ -8,6 +8,7 @@ from torch_geometric.data import Data, HeteroData from torch_geometric.data.feature_store import FeatureStore +from torch_geometric.data.graph_store import GraphStore from torch_geometric.data.storage import EdgeStorage, NodeStorage from torch_geometric.typing import EdgeType, OptTensor @@ -31,10 +32,6 @@ def edge_type_to_str(edge_type: Union[EdgeType, str]) -> str: return edge_type if isinstance(edge_type, str) else '__'.join(edge_type) -def str_to_edge_type(key: Union[EdgeType, str]) -> EdgeType: - return key if isinstance(key, tuple) else tuple(key.split('__')) - - # TODO deprecate when FeatureStore / GraphStore unification is complete def to_csc( data: Union[Data, EdgeStorage], @@ -188,8 +185,9 @@ def filter_hetero_data( return out -def filter_feature_store( +def filter_custom_store( feature_store: FeatureStore, + graph_store: GraphStore, node_dict: Dict[str, Tensor], row_dict: Dict[str, Tensor], col_dict: Dict[str, Tensor], @@ -204,14 +202,15 @@ def filter_feature_store( # Filter edge storage: # TODO support edge attributes - for key in edge_dict: - edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0) - data[str_to_edge_type(key)].edge_index = edge_index + for attr in graph_store.get_all_edge_attrs(): + key = edge_type_to_str(attr.edge_type) + if key in row_dict and key in col_dict: + edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0) + data[attr.edge_type].edge_index = edge_index # Filter node storage: - attrs = feature_store.get_all_tensor_attrs() required_attrs = [] - for attr in attrs: + for attr in feature_store.get_all_tensor_attrs(): if attr.group_name in node_dict: attr.index = node_dict[attr.group_name] required_attrs.append(attr)