diff --git a/CHANGELOG.md b/CHANGELOG.md
index a349416fc941..67e5b2bfd7fe 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807))
 - Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
 - Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
 - Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))
diff --git a/test/data/test_data.py b/test/data/test_data.py
index b794364be308..be06d43bed0b 100644
--- a/test/data/test_data.py
+++ b/test/data/test_data.py
@@ -239,3 +239,28 @@ def my_attr1(self, value):
     data.my_attr1 = 2
     assert 'my_attr1' not in data._store
     assert data.my_attr1 == 2
+
+
+# Feature Store ###############################################################
+
+
+def test_basic_feature_store():
+    data = Data()
+    x = torch.randn(20, 20)
+
+    # Put tensor:
+    assert data.put_tensor(copy.deepcopy(x), attr_name='x', index=None)
+    assert torch.equal(data.x, x)
+
+    # Put (modify) tensor slice:
+    x[15:] = 0
+    data.put_tensor(0, attr_name='x', index=slice(15, None, None))
+
+    # Get tensor:
+    out = data.get_tensor(attr_name='x', index=None)
+    assert torch.equal(x, out)
+
+    # Remove tensor:
+    assert 'x' in data.__dict__['_store']
+    data.remove_tensor(attr_name='x', index=None)
+    assert 'x' not in data.__dict__['_store']
diff --git a/test/data/test_hetero_data.py b/test/data/test_hetero_data.py
index ba5f7a33f389..b26832bcb068 100644
--- a/test/data/test_hetero_data.py
+++ b/test/data/test_hetero_data.py
@@ -400,3 +400,30 @@ def test_hetero_data_to_canonical():
 
     with pytest.raises(TypeError, match="missing 1 required"):
         data['user', 'product']
+
+
+# Feature Store ###############################################################
+
+
+def test_basic_feature_store():
+    data = HeteroData()
+    x = torch.randn(20, 20)
+
+    # Put tensor:
+    assert data.put_tensor(copy.deepcopy(x), group_name='paper', attr_name='x',
+                           index=None)
+    assert torch.equal(data['paper'].x, x)
+
+    # Put (modify) tensor slice:
+    x[15:] = 0
+    data.put_tensor(0, group_name='paper', attr_name='x',
+                    index=slice(15, None, None))
+
+    # Get tensor:
+    out = data.get_tensor(group_name='paper', attr_name='x', index=None)
+    assert torch.equal(x, out)
+
+    # Remove tensor:
+    assert 'x' in data['paper'].__dict__['_mapping']
+    data.remove_tensor(group_name='paper', attr_name='x', index=None)
+    assert 'x' not in data['paper'].__dict__['_mapping']
diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py
index ecaae4d663b3..43e553ab1097 100644
--- a/torch_geometric/data/batch.py
+++ b/torch_geometric/data/batch.py
@@ -23,8 +23,16 @@ def __call__(cls, *args, **kwargs):
             new_cls = base_cls
         else:
             name = f'{base_cls.__name__}{cls.__name__}'
+
+            # NOTE `MetaResolver` is necessary to resolve metaclass conflicts
+            # between `DynamicInheritance` and the metaclass of `base_cls`.
+            # In particular, it creates a new common metaclass from the
+            # defined metaclasses.
+            class MetaResolver(type(cls), type(base_cls)):
+                pass
+
             if name not in globals():
-                globals()[name] = type(name, (cls, base_cls), {})
+                globals()[name] = MetaResolver(name, (cls, base_cls), {})
             new_cls = globals()[name]
 
     params = list(inspect.signature(base_cls.__init__).parameters.items())
diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py
index 580c9bdd3b6e..3a222246b44e 100644
--- a/torch_geometric/data/data.py
+++ b/torch_geometric/data/data.py
@@ -1,5 +1,6 @@
 import copy
 from collections.abc import Mapping, Sequence
+from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
@@ -17,6 +18,12 @@
 from torch import Tensor
 from torch_sparse import SparseTensor
 
+from torch_geometric.data.feature_store import (
+    FeatureStore,
+    FeatureTensorType,
+    TensorAttr,
+    _field_status,
+)
 from torch_geometric.data.storage import (
     BaseStorage,
     EdgeStorage,
@@ -300,7 +307,16 @@ def contains_self_loops(self) -> bool:
 ###############################################################################
 
 
-class Data(BaseData):
+@dataclass
+class DataTensorAttr(TensorAttr):
+    r"""Attribute class for `Data`, which does not require a `group_name`."""
+    def __init__(self, attr_name=_field_status.UNSET,
+                 index=_field_status.UNSET):
+        # Treat group_name as optional, and move it to the end:
+        super().__init__(None, attr_name, index)
+
+
+class Data(BaseData, FeatureStore):
     r"""A data object describing a homogeneous graph.
     The data object can hold node-level, link-level and graph-level attributes.
     In general, :class:`~torch_geometric.data.Data` tries to mimic the
@@ -348,7 +364,10 @@ class Data(BaseData):
     def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
                  edge_attr: OptTensor = None, y: OptTensor = None,
                  pos: OptTensor = None, **kwargs):
-        super().__init__()
+        # `Data` doesn't support group_name, so we need to adjust `TensorAttr`
+        # accordingly here to avoid requiring `group_name` to be set:
+        super().__init__(attr_cls=DataTensorAttr)
+
         self.__dict__['_store'] = GlobalStorage(_parent=self)
 
         if x is not None:
@@ -384,6 +403,9 @@ def __setattr__(self, key: str, value: Any):
     def __delattr__(self, key: str):
         delattr(self._store, key)
 
+    # TODO consider supporting the feature store interface for
+    # __getitem__, __setitem__, and __delitem__ so, for example, we
+    # can accept key: Union[str, TensorAttr] in __getitem__.
     def __getitem__(self, key: str) -> Any:
         return self._store[key]
 
@@ -692,6 +714,47 @@ def num_faces(self) -> Optional[int]:
             return self.face.size(self.__cat_dim__('face', self.face))
         return None
 
+    # FeatureStore interface ###########################################
+
+    def items(self):
+        r"""Returns an `ItemsView` over the stored attributes in the `Data`
+        object."""
+        return self._store.items()
+
+    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
+        r"""Stores a feature tensor in node storage."""
+        out = getattr(self, attr.attr_name, None)
+        if out is not None and attr.index is not None:
+            # Attr name exists, handle index:
+            out[attr.index] = tensor
+        else:
+            # No attr name (or None index), just store tensor:
+            setattr(self, attr.attr_name, tensor)
+        return True
+
+    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
+        r"""Obtains a feature tensor from node storage."""
+        # Retrieve tensor and index accordingly:
+        tensor = getattr(self, attr.attr_name, None)
+        if tensor is not None:
+            # TODO this behavior is a bit odd, since TensorAttr requires that
+            # we set `index`. So, we assume here that indexing by `None` is
+            # equivalent to not indexing at all, which is not in line with
+            # Python semantics.
+            return tensor[attr.index] if attr.index is not None else tensor
+        return None
+
+    def _remove_tensor(self, attr: TensorAttr) -> bool:
+        r"""Deletes a feature tensor from node storage."""
+        # Remove tensor entirely:
+        if hasattr(self, attr.attr_name):
+            delattr(self, attr.attr_name)
+            return True
+        return False
+
+    def __len__(self) -> int:
+        return BaseData.__len__(self)
 
 
 ###############################################################################
diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py
index bc7d10322497..b9c2aa623cc6 100644
--- a/torch_geometric/data/feature_store.py
+++ b/torch_geometric/data/feature_store.py
@@ -245,7 +245,7 @@ def __init__(self, attr_cls: Any = TensorAttr):
         attributes by subclassing :class:`TensorAttr` and passing the subclass
         as :obj:`attr_cls`."""
         super().__init__()
-        self._attr_cls = attr_cls
+        self.__dict__['_attr_cls'] = attr_cls
 
     # Core (CRUD) #############################################################
diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py
index d4e77c1a80e3..051833a36371 100644
--- a/torch_geometric/data/hetero_data.py
+++ b/torch_geometric/data/hetero_data.py
@@ -10,6 +10,11 @@
 from torch_sparse import SparseTensor
 
 from torch_geometric.data.data import BaseData, Data, size_repr
+from torch_geometric.data.feature_store import (
+    FeatureStore,
+    FeatureTensorType,
+    TensorAttr,
+)
 from torch_geometric.data.storage import BaseStorage, EdgeStorage, NodeStorage
 from torch_geometric.typing import EdgeType, NodeType, QueryType
 from torch_geometric.utils import bipartite_subgraph, is_undirected
@@ -18,7 +23,7 @@
 NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage]
 
 
-class HeteroData(BaseData):
+class HeteroData(BaseData, FeatureStore):
     r"""A data object describing a heterogeneous graph, holding multiple node
     and/or edge types in disjunct storage objects.
     Storage objects can hold either node-level, link-level or graph-level
@@ -92,6 +97,8 @@ class HeteroData(BaseData):
     DEFAULT_REL = 'to'
 
     def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs):
+        super().__init__()
+
         self.__dict__['_global_store'] = BaseStorage(_parent=self)
         self.__dict__['_node_store_dict'] = {}
         self.__dict__['_edge_store_dict'] = {}
@@ -616,6 +623,52 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:
 
         return data
 
+    # :obj:`FeatureStore` interface ###########################################
+
+    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
+        r"""Stores a feature tensor in node storage."""
+        if not attr.is_set('index'):
+            attr.index = None
+
+        out = self._node_store_dict.get(attr.group_name, None)
+        if out:
+            # Group name exists, handle index or create new attribute name:
+            val = getattr(out, attr.attr_name)
+            if val is not None:
+                val[attr.index] = tensor
+            else:
+                setattr(self[attr.group_name], attr.attr_name, tensor)
+        else:
+            # No node storage found, just store tensor in new one:
+            setattr(self[attr.group_name], attr.attr_name, tensor)
+        return True
+
+    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
+        r"""Obtains a feature tensor from node storage."""
+        # Retrieve tensor and index accordingly:
+        tensor = getattr(self[attr.group_name], attr.attr_name, None)
+        if tensor is not None:
+            # TODO this behavior is a bit odd, since TensorAttr requires that
+            # we set `index`. So, we assume here that indexing by `None` is
+            # equivalent to not indexing at all, which is not in line with
+            # Python semantics.
+            return tensor[attr.index] if attr.index is not None else tensor
+        return None
+
+    def _remove_tensor(self, attr: TensorAttr) -> bool:
+        r"""Deletes a feature tensor from node storage."""
+        # Remove tensor entirely:
+        if hasattr(self[attr.group_name], attr.attr_name):
+            delattr(self[attr.group_name], attr.attr_name)
+            return True
+        return False
+
+    def __len__(self) -> int:
+        return BaseData.__len__(self)
+
+    def __iter__(self):
+        raise NotImplementedError
 
 
 # Helper functions ############################################################
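
Taken together, these changes let `Data` and `HeteroData` act as `FeatureStore` backends while keeping their usual attribute access intact. Below is a minimal usage sketch (not part of the diff) distilled from the new tests above; the `paper` group name and the tensor shapes are illustrative only:

    import copy

    import torch

    from torch_geometric.data import Data, HeteroData

    # Homogeneous graphs: `DataTensorAttr` makes `group_name` optional, so a
    # tensor is addressed by `attr_name` and `index` alone.
    data = Data()
    x = torch.randn(20, 20)
    # Store a copy so that later writes through the feature store do not
    # alias the local `x` (tensors are stored by reference):
    data.put_tensor(copy.deepcopy(x), attr_name='x', index=None)
    assert torch.equal(data.x, x)  # Regular attribute access still works.
    data.put_tensor(0, attr_name='x', index=slice(15, None))  # Slice write.
    data.remove_tensor(attr_name='x', index=None)

    # Heterogeneous graphs: a tensor is addressed by the full
    # (group_name, attr_name, index) triple of `TensorAttr`:
    data = HeteroData()
    data.put_tensor(x, group_name='paper', attr_name='x', index=None)
    out = data.get_tensor(group_name='paper', attr_name='x', index=slice(0, 10))
    assert out.size() == (10, 20)

Note that `get_tensor` with `index=None` returns the full tensor: per the `TODO` comments above, a `None` index is treated as "no indexing" rather than as a literal key.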