Skip to content

Commit

Permalink
Fix: properly filter custom feature/graph stores (#5088)
Browse files Browse the repository at this point in the history
* init

* changelog

* typo

* typo

* update

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
mananshah99 and rusty1s authored Jul 30, 2022
1 parent 2a16256 commit 58df978
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815), [#4862](https://github.com/pyg-team/pytorch_geometric/pull/4862/files))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037), [#5088](https://github.com/pyg-team/pytorch_geometric/pull/5088))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.neighbor_loader import NeighborSampler
from torch_geometric.loader.utils import (
filter_custom_store,
filter_data,
filter_feature_store,
filter_hetero_data,
)
from torch_geometric.typing import InputEdges, NumNeighbors, OptTensor
Expand Down Expand Up @@ -339,9 +339,9 @@ def filter_fn(self, out: Any) -> Union[Data, HeteroData]:
else:
(node_dict, row_dict, col_dict, edge_dict, edge_label_index,
edge_label) = out
feature_store, _ = self.data
data = filter_feature_store(feature_store, node_dict, row_dict,
col_dict, edge_dict)
feature_store, graph_store = self.data
data = filter_custom_store(feature_store, graph_store, node_dict,
row_dict, col_dict, edge_dict)
edge_type = self.neighbor_sampler.input_type
data[edge_type].edge_label_index = edge_label_index
if edge_label is not None:
Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.utils import (
edge_type_to_str,
filter_custom_store,
filter_data,
filter_feature_store,
filter_hetero_data,
to_csc,
to_hetero_csc,
Expand Down Expand Up @@ -415,9 +415,9 @@ def filter_fn(self, out: Any) -> Union[Data, HeteroData]:
else: # Tuple[FeatureStore, GraphStore]
# TODO support for feature stores with no edge types
node_dict, row_dict, col_dict, edge_dict, batch_size = out
feature_store, _ = self.data
data = filter_feature_store(feature_store, node_dict, row_dict,
col_dict, edge_dict)
feature_store, graph_store = self.data
data = filter_custom_store(feature_store, graph_store, node_dict,
row_dict, col_dict, edge_dict)
data[self.neighbor_sampler.input_type].batch_size = batch_size

return data if self.transform is None else self.transform(data)
Expand Down
19 changes: 9 additions & 10 deletions torch_geometric/loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.data.storage import EdgeStorage, NodeStorage
from torch_geometric.typing import EdgeType, OptTensor

Expand All @@ -31,10 +32,6 @@ def edge_type_to_str(edge_type: Union[EdgeType, str]) -> str:
return edge_type if isinstance(edge_type, str) else '__'.join(edge_type)


def str_to_edge_type(key: Union[EdgeType, str]) -> EdgeType:
return key if isinstance(key, tuple) else tuple(key.split('__'))


# TODO deprecate when FeatureStore / GraphStore unification is complete
def to_csc(
data: Union[Data, EdgeStorage],
Expand Down Expand Up @@ -188,8 +185,9 @@ def filter_hetero_data(
return out


def filter_feature_store(
def filter_custom_store(
feature_store: FeatureStore,
graph_store: GraphStore,
node_dict: Dict[str, Tensor],
row_dict: Dict[str, Tensor],
col_dict: Dict[str, Tensor],
Expand All @@ -204,14 +202,15 @@ def filter_feature_store(

# Filter edge storage:
# TODO support edge attributes
for key in edge_dict:
edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
data[str_to_edge_type(key)].edge_index = edge_index
for attr in graph_store.get_all_edge_attrs():
key = edge_type_to_str(attr.edge_type)
if key in row_dict and key in col_dict:
edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
data[attr.edge_type].edge_index = edge_index

# Filter node storage:
attrs = feature_store.get_all_tensor_attrs()
required_attrs = []
for attr in attrs:
for attr in feature_store.get_all_tensor_attrs():
if attr.group_name in node_dict:
attr.index = node_dict[attr.group_name]
required_attrs.append(attr)
Expand Down

0 comments on commit 58df978

Please sign in to comment.