Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: properly filter custom feature/graph stores #5088

Merged
merged 8 commits into from
Jul 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815), [#4862](https://github.com/pyg-team/pytorch_geometric/pull/4862/files))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4922](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4922](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037), [#5088](https://github.com/pyg-team/pytorch_geometric/pull/5088))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.neighbor_loader import NeighborSampler
from torch_geometric.loader.utils import (
filter_custom_store,
filter_data,
filter_feature_store,
filter_hetero_data,
)
from torch_geometric.typing import InputEdges, NumNeighbors, OptTensor
Expand Down Expand Up @@ -339,9 +339,9 @@ def filter_fn(self, out: Any) -> Union[Data, HeteroData]:
else:
(node_dict, row_dict, col_dict, edge_dict, edge_label_index,
edge_label) = out
feature_store, _ = self.data
data = filter_feature_store(feature_store, node_dict, row_dict,
col_dict, edge_dict)
feature_store, graph_store = self.data
data = filter_custom_store(feature_store, graph_store, node_dict,
row_dict, col_dict, edge_dict)
edge_type = self.neighbor_sampler.input_type
data[edge_type].edge_label_index = edge_label_index
if edge_label is not None:
Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.utils import (
edge_type_to_str,
filter_custom_store,
filter_data,
filter_feature_store,
filter_hetero_data,
to_csc,
to_hetero_csc,
Expand Down Expand Up @@ -415,9 +415,9 @@ def filter_fn(self, out: Any) -> Union[Data, HeteroData]:
else: # Tuple[FeatureStore, GraphStore]
# TODO support for feature stores with no edge types
node_dict, row_dict, col_dict, edge_dict, batch_size = out
feature_store, _ = self.data
data = filter_feature_store(feature_store, node_dict, row_dict,
col_dict, edge_dict)
feature_store, graph_store = self.data
data = filter_custom_store(feature_store, graph_store, node_dict,
row_dict, col_dict, edge_dict)
data[self.neighbor_sampler.input_type].batch_size = batch_size

return data if self.transform is None else self.transform(data)
Expand Down
19 changes: 9 additions & 10 deletions torch_geometric/loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.data.storage import EdgeStorage, NodeStorage
from torch_geometric.typing import EdgeType, OptTensor

Expand All @@ -31,10 +32,6 @@ def edge_type_to_str(edge_type: Union[EdgeType, str]) -> str:
return edge_type if isinstance(edge_type, str) else '__'.join(edge_type)


def str_to_edge_type(key: Union[EdgeType, str]) -> EdgeType:
return key if isinstance(key, tuple) else tuple(key.split('__'))


# TODO deprecate when FeatureStore / GraphStore unification is complete
def to_csc(
data: Union[Data, EdgeStorage],
Expand Down Expand Up @@ -188,8 +185,9 @@ def filter_hetero_data(
return out


def filter_feature_store(
def filter_custom_store(
feature_store: FeatureStore,
graph_store: GraphStore,
node_dict: Dict[str, Tensor],
row_dict: Dict[str, Tensor],
col_dict: Dict[str, Tensor],
Expand All @@ -204,14 +202,15 @@ def filter_feature_store(

# Filter edge storage:
# TODO support edge attributes
for key in edge_dict:
edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
data[str_to_edge_type(key)].edge_index = edge_index
for attr in graph_store.get_all_edge_attrs():
key = edge_type_to_str(attr.edge_type)
if key in row_dict and key in col_dict:
edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
data[attr.edge_type].edge_index = edge_index

# Filter node storage:
attrs = feature_store.get_all_tensor_attrs()
required_attrs = []
for attr in attrs:
for attr in feature_store.get_all_tensor_attrs():
if attr.group_name in node_dict:
attr.index = node_dict[attr.group_name]
required_attrs.append(attr)
Expand Down