From 3fec3c66d95e8ddddd6966c70377aab8ce21543a Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Mon, 25 Apr 2022 21:03:28 +0000 Subject: [PATCH 01/18] Feature store abstraction + tests --- test/data/test_feature_store.py | 86 ++++++++++ torch_geometric/data/feature_store.py | 218 ++++++++++++++++++++++++++ torch_geometric/typing.py | 5 + torch_geometric/utils/mixin.py | 14 ++ 4 files changed, 323 insertions(+) create mode 100644 test/data/test_feature_store.py create mode 100644 torch_geometric/data/feature_store.py create mode 100644 torch_geometric/utils/mixin.py diff --git a/test/data/test_feature_store.py b/test/data/test_feature_store.py new file mode 100644 index 000000000000..91e958ef8fcf --- /dev/null +++ b/test/data/test_feature_store.py @@ -0,0 +1,86 @@ +from typing import Optional + +import torch + +from torch_geometric.data.feature_store import FeatureStore, TensorAttr +from torch_geometric.typing import TensorType + + +class MyFeatureStore(FeatureStore): + r"""A basic feature store, does NOT implement all functionality of a + fully-fledged feature store. Only works for Torch tensors.""" + def __init__(self): + super().__init__(backend='test') + self.store = {} + + @classmethod + def key(cls, attr: TensorAttr): + return (attr.tensor_type or '', attr.node_type or '', attr.graph_type + or '') + + def _put_tensor(self, tensor: TensorType, attr: TensorAttr) -> bool: + self.store[MyFeatureStore.key(attr)] = torch.cat( + (attr.index.reshape(-1, 1), tensor), dim=1) + return True + + def _get_tensor(self, attr: TensorAttr) -> Optional[TensorType]: + tensor = self.store.get(MyFeatureStore.key(attr), None) + if tensor is None: + return None + if attr.index is not None: + indices = torch.cat([(tensor[:, 0] == v).nonzero() + for v in attr.index]).reshape(1, -1)[0] + + return torch.index_select(tensor[:, 1:], 0, indices) + return tensor[:, 1:] + + def _remove_tensor(self, attr: TensorAttr) -> bool: + if attr.index is not None: + raise NotImplementedError + del self.store[MyFeatureStore.key(attr)] + + def __len__(self): + raise NotImplementedError + + +def test_feature_store(): + r"""Tests basic API and indexing functionality of a feature store.""" + + store = MyFeatureStore() + tensor = torch.Tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) + index = torch.Tensor([0, 1, 2]) + + tensor_type = 'feat' + node_type = 'A' + graph_type = 'graph' + + attr = TensorAttr(index, tensor_type, node_type, graph_type) + + # Normal API + store.put_tensor(tensor, attr) + assert torch.equal(store.get_tensor(attr), tensor) + assert torch.equal( + store.get_tensor( + (torch.Tensor([0, 2]), tensor_type, node_type, graph_type)), + tensor[[0, 2]], + ) + assert store.get_tensor((index)) is None + store.remove_tensor((None, tensor_type, node_type, graph_type)) + assert store.get_tensor(attr) is None + + # Indexing + store[attr] = tensor + assert torch.equal(store[attr], tensor) + assert torch.equal( + store[(torch.Tensor([0, 2]), tensor_type, node_type, graph_type)], + tensor[[0, 2]], + ) + assert store[(index)] is None + del store[(None, tensor_type, node_type, graph_type)] + assert store.get_tensor(attr) is None + + # Advanced indexing + store[attr] = tensor + assert (torch.equal( + store[TensorAttr(node_type=node_type, + graph_type=graph_type)].feat[index], tensor)) diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py new file mode 100644 index 000000000000..84ad0bcd751a --- /dev/null +++ b/torch_geometric/data/feature_store.py @@ -0,0 +1,218 @@ +r""" +This class defines the abstraction for a Graph feature store. The goal of a +feature store is to abstract away all node and edge feature memory management +so that varying implementations can allow for independent scale-out. + +This particular feature store abstraction makes a few key assumptions: + * The features we care about storing are all associated with some sort of + `index`; explicitly for PyG the the index of the node in the graph (or + the heterogeneous component of the graph it resides in). + * A feature can uniquely be identified from (a) its index and (b) any other + associated attributes specified in :obj:`TensorAttr`. + +It is the job of a feature store implementor class to handle these assumptions +properly. For example, a simple in-memory feature store implementation may +concatenate all metadata values with a feature index and use this as a unique +index in a KV store. More complicated implementations may choose to partition +features in interesting manners based on the provided metadata. + +Major TODOs for future implementation: +* Async `put` and `get` functionality +""" +from abc import abstractmethod +from collections.abc import MutableMapping +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np +import torch + +from torch_geometric.typing import TensorType +from torch_geometric.utils.mixin import CastMixin + + +@dataclass +class TensorAttr(CastMixin): + r"""Defines the attributes of a :obj:`FeatureStore` tensor.""" + + # The node indices the rows of the tensor correspond to + index: Optional[TensorType] = None + + # The type of the feature tensor (may be used if there are multiple + # different feature tensors for the same node index) + tensor_type: Optional[str] = None + + # The type of the nodes that the tensor corresponds to (may be used for + # hetereogeneous graphs) + node_type: Optional[str] = None + + # The type of the graph that the nodes correspond to (may be used if a + # feature store supports multiple graphs) + graph_type: Optional[str] = None + + +class AttrView: + r"""A view of a :obj:`FeatureStore` that is obtained from an incomplete + specification of attributes. This view stores a reference to the + originating feature store as well as a :obj:`TensorAttr` object that + represents the view's (incomplete) state. + + As a result, store[TensorAttr(...)].tensor_type[idx] allows for indexing + into the store. + """ + _store: 'FeatureStore' + attr: TensorAttr + + def __init__(self, store, attr): + self._store = store + self.attr = attr + + def __getattr__(self, tensor_type): + r"""Supports attr_view.attr""" + self.attr.tensor_type = tensor_type + return self + + def __getitem__(self, index: TensorType): + r"""Supports attr_view.attr[idx]""" + self.attr.index = index + return self._store.get_tensor(self.attr) + + def __repr__(self) -> str: + return f'AttrView(store={self._store}, attr={self.attr})' + + +class FeatureStore(MutableMapping): + def __init__(self, backend: Any): + r"""Initializes the feature store with a specified backend.""" + self.backend = backend + + # Core (CRUD) ############################################################# + + @abstractmethod + def _put_tensor(self, tensor: TensorType, attr: TensorAttr) -> bool: + r"""Implemented by :obj:`FeatureStore` subclasses.""" + pass + + def put_tensor(self, tensor: TensorType, attr: TensorAttr) -> bool: + r"""Synchronously adds a :obj:`TensorType` object to the feature store. + + Args: + tensor (TensorType): the features to be added. + attr (TensorAttr): any relevant tensor attributes that correspond + to the feature tensor. See the :obj:`TensorAttr` documentation + for required and optional attributes. It is the job of + implementations of a FeatureStore to store this metadata in a + meaningful way that allows for tensor retrieval from a + :obj:`TensorAttr` object. + Returns: + bool: whether insertion was successful. + """ + attr = TensorAttr.cast(attr) + assert attr.index is not None + assert attr.index.size(dim=0) == tensor.size(dim=-1) + return self._put_tensor(tensor, attr) + + @abstractmethod + def _get_tensor(self, attr: TensorAttr) -> Optional[TensorType]: + r"""Implemented by :obj:`FeatureStore` subclasses.""" + pass + + def get_tensor(self, attr: TensorAttr) -> Optional[TensorType]: + r"""Synchronously obtains a :obj:`TensorType` object from the feature + store. Feature store implementors guarantee that the call + get_tensor(put_tensor(tensor, attr), attr) = tensor. + + Args: + attr (TensorAttr): any relevant tensor attributes that correspond + to the tensor to obtain. See :obj:`TensorAttr` documentation + for required and optional attributes. It is the job of + implementations of a FeatureStore to store this metadata in a + meaningful way that allows for tensor retrieval from a + :obj:`TensorAttr` object. + Returns: + TensorType, optional: a tensor of the same type as the index, or + None if no tensor was found. + """ + def maybe_to_torch(x): + return torch.from_numpy(x) if isinstance( + attr.index, torch.Tensor) and isinstance(x, np.ndarray) else x + + attr = TensorAttr.cast(attr) + assert attr.index is not None + + return maybe_to_torch(self._get_tensor(attr)) + + @abstractmethod + def _remove_tensor(self, attr: TensorAttr) -> bool: + r"""Implemented by :obj:`FeatureStore` subclasses.""" + pass + + def remove_tensor(self, attr: TensorAttr) -> bool: + r"""Removes a :obj:`TensorType` object from the feature store. + + Args: + attr (TensorAttr): any relevant tensor attributes that correspond + to the tensor to remove. See :obj:`TensorAttr` documentation + for required and optional attributes. It is the job of + implementations of a FeatureStore to store this metadata in a + meaningful way that allows for tensor deletion from a + :obj:`TensorAttr` object. + + Returns: + bool: whether deletion was succesful. + """ + attr = TensorAttr.cast(attr) + self._remove_tensor(attr) + + def update_tensor(self, tensor: TensorType, attr: TensorAttr) -> bool: + r"""Updates a :obj:`TensorType` object with a new value. Implementor + classes can choose to define more efficient update methods; the default + performs a removal and insertion. + + Args: + tensor (TensorType): the features to be added. + attr (TensorAttr): any relevant tensor attributes that correspond + to the old tensor. See :obj:`TensorAttr` documentation + for required and optional attributes. It is the job of + implementations of a FeatureStore to store this metadata in a + meaningful way that allows for tensor update from a + :obj:`TensorAttr` object. + + Returns: + bool: whether the update was succesful. + """ + attr = TensorAttr.cast(attr) + self.remove_tensor(attr) + return self.put_tensor(tensor, attr) + + # Python built-ins ######################################################## + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(backend={self.backend})' + + def __setitem__(self, key: TensorAttr, value: TensorType): + r"""Supports store[tensor_attr] = tensor.""" + key = TensorAttr.cast(key) + assert key.index is not None + self.put_tensor(value, key) + + def __getitem__(self, key: TensorAttr): + r"""Supports store[tensor_attr]. If tensor_attr has index specified, + will obtain the corresponding features from the store. Otherwise, will + return an :obj:`AttrView` which can be indexed independently.""" + key = TensorAttr.cast(key) + if key.index is not None: + return self.get_tensor(key) + return AttrView(self, key) + + def __delitem__(self, key: TensorAttr): + r"""Supports del store[tensor_attr].""" + key = TensorAttr.cast(key) + self.remove_tensor(key) + + def __iter__(self): + raise NotImplementedError + + @abstractmethod + def __len__(self): + pass diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index cb672e6aef51..57df5a7ad679 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -1,5 +1,7 @@ from typing import Dict, List, Optional, Tuple, Union +import numpy +import torch from torch import Tensor from torch_sparse import SparseTensor @@ -20,6 +22,9 @@ Metadata = Tuple[List[NodeType], List[EdgeType]] +# A representation of a tensor, either as a torch Tensor or a numpy ndarray +TensorType = Union[torch.TensorType, numpy.ndarray] + # Types for message passing ################################################### Adj = Union[Tensor, SparseTensor] diff --git a/torch_geometric/utils/mixin.py b/torch_geometric/utils/mixin.py new file mode 100644 index 000000000000..7f14a10a2dad --- /dev/null +++ b/torch_geometric/utils/mixin.py @@ -0,0 +1,14 @@ +class CastMixin: + @classmethod + def cast(cls, *args, **kwargs): # TODO Can we apply this recursively? + if len(args) == 1 and len(kwargs) == 0: + elem = args[0] + if elem is None: + return None + if isinstance(elem, CastMixin): + return elem + if isinstance(elem, (tuple, list)): + return cls(*elem) + if isinstance(elem, dict): + return cls(**elem) + return cls(*args, **kwargs) From e962338db6f79d103b6f13c1d4b2c9e0e69ba0aa Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Mon, 25 Apr 2022 21:59:04 +0000 Subject: [PATCH 02/18] TensorType -> FeatureTensorType --- test/data/test_feature_store.py | 6 ++-- torch_geometric/data/feature_store.py | 42 ++++++++++++++------------- torch_geometric/typing.py | 4 +-- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/test/data/test_feature_store.py b/test/data/test_feature_store.py index 91e958ef8fcf..6bbe065a4c3c 100644 --- a/test/data/test_feature_store.py +++ b/test/data/test_feature_store.py @@ -3,7 +3,7 @@ import torch from torch_geometric.data.feature_store import FeatureStore, TensorAttr -from torch_geometric.typing import TensorType +from torch_geometric.typing import FeatureTensorType class MyFeatureStore(FeatureStore): @@ -18,12 +18,12 @@ def key(cls, attr: TensorAttr): return (attr.tensor_type or '', attr.node_type or '', attr.graph_type or '') - def _put_tensor(self, tensor: TensorType, attr: TensorAttr) -> bool: + def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: self.store[MyFeatureStore.key(attr)] = torch.cat( (attr.index.reshape(-1, 1), tensor), dim=1) return True - def _get_tensor(self, attr: TensorAttr) -> Optional[TensorType]: + def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: tensor = self.store.get(MyFeatureStore.key(attr), None) if tensor is None: return None diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index 84ad0bcd751a..13cede4d74cd 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -27,7 +27,7 @@ import numpy as np import torch -from torch_geometric.typing import TensorType +from torch_geometric.typing import FeatureTensorType from torch_geometric.utils.mixin import CastMixin @@ -36,7 +36,7 @@ class TensorAttr(CastMixin): r"""Defines the attributes of a :obj:`FeatureStore` tensor.""" # The node indices the rows of the tensor correspond to - index: Optional[TensorType] = None + index: Optional[FeatureTensorType] = None # The type of the feature tensor (may be used if there are multiple # different feature tensors for the same node index) @@ -72,7 +72,7 @@ def __getattr__(self, tensor_type): self.attr.tensor_type = tensor_type return self - def __getitem__(self, index: TensorType): + def __getitem__(self, index: FeatureTensorType): r"""Supports attr_view.attr[idx]""" self.attr.index = index return self._store.get_tensor(self.attr) @@ -89,15 +89,16 @@ def __init__(self, backend: Any): # Core (CRUD) ############################################################# @abstractmethod - def _put_tensor(self, tensor: TensorType, attr: TensorAttr) -> bool: + def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: r"""Implemented by :obj:`FeatureStore` subclasses.""" pass - def put_tensor(self, tensor: TensorType, attr: TensorAttr) -> bool: - r"""Synchronously adds a :obj:`TensorType` object to the feature store. + def put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: + r"""Synchronously adds a :obj:`FeatureTensorType` object to the feature + store. Args: - tensor (TensorType): the features to be added. + tensor (FeatureTensorType): the features to be added. attr (TensorAttr): any relevant tensor attributes that correspond to the feature tensor. See the :obj:`TensorAttr` documentation for required and optional attributes. It is the job of @@ -113,13 +114,13 @@ def put_tensor(self, tensor: TensorType, attr: TensorAttr) -> bool: return self._put_tensor(tensor, attr) @abstractmethod - def _get_tensor(self, attr: TensorAttr) -> Optional[TensorType]: + def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: r"""Implemented by :obj:`FeatureStore` subclasses.""" pass - def get_tensor(self, attr: TensorAttr) -> Optional[TensorType]: - r"""Synchronously obtains a :obj:`TensorType` object from the feature - store. Feature store implementors guarantee that the call + def get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: + r"""Synchronously obtains a :obj:`FeatureTensorType` object from the + feature store. Feature store implementors guarantee that the call get_tensor(put_tensor(tensor, attr), attr) = tensor. Args: @@ -130,8 +131,8 @@ def get_tensor(self, attr: TensorAttr) -> Optional[TensorType]: meaningful way that allows for tensor retrieval from a :obj:`TensorAttr` object. Returns: - TensorType, optional: a tensor of the same type as the index, or - None if no tensor was found. + FeatureTensorType, optional: a tensor of the same type as the + index, or None if no tensor was found. """ def maybe_to_torch(x): return torch.from_numpy(x) if isinstance( @@ -148,7 +149,7 @@ def _remove_tensor(self, attr: TensorAttr) -> bool: pass def remove_tensor(self, attr: TensorAttr) -> bool: - r"""Removes a :obj:`TensorType` object from the feature store. + r"""Removes a :obj:`FeatureTensorType` object from the feature store. Args: attr (TensorAttr): any relevant tensor attributes that correspond @@ -164,13 +165,14 @@ def remove_tensor(self, attr: TensorAttr) -> bool: attr = TensorAttr.cast(attr) self._remove_tensor(attr) - def update_tensor(self, tensor: TensorType, attr: TensorAttr) -> bool: - r"""Updates a :obj:`TensorType` object with a new value. Implementor - classes can choose to define more efficient update methods; the default - performs a removal and insertion. + def update_tensor(self, tensor: FeatureTensorType, + attr: TensorAttr) -> bool: + r"""Updates a :obj:`FeatureTensorType` object with a new value. + implementor classes can choose to define more efficient update methods; + the default performs a removal and insertion. Args: - tensor (TensorType): the features to be added. + tensor (FeatureTensorType): the features to be added. attr (TensorAttr): any relevant tensor attributes that correspond to the old tensor. See :obj:`TensorAttr` documentation for required and optional attributes. It is the job of @@ -190,7 +192,7 @@ def update_tensor(self, tensor: TensorType, attr: TensorAttr) -> bool: def __repr__(self) -> str: return f'{self.__class__.__name__}(backend={self.backend})' - def __setitem__(self, key: TensorAttr, value: TensorType): + def __setitem__(self, key: TensorAttr, value: FeatureTensorType): r"""Supports store[tensor_attr] = tensor.""" key = TensorAttr.cast(key) assert key.index is not None diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index 57df5a7ad679..4fdec3eccd87 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -22,8 +22,8 @@ Metadata = Tuple[List[NodeType], List[EdgeType]] -# A representation of a tensor, either as a torch Tensor or a numpy ndarray -TensorType = Union[torch.TensorType, numpy.ndarray] +# A representation of a feature tensor +FeatureTensorType = Union[torch.TensorType, numpy.ndarray] # Types for message passing ################################################### From b347a3c3846339689d25e8a7b05aa073140d90c9 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Tue, 26 Apr 2022 00:46:49 +0000 Subject: [PATCH 03/18] to_type --- torch_geometric/data/feature_store.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index 13cede4d74cd..732d79353a77 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -134,14 +134,19 @@ def get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: FeatureTensorType, optional: a tensor of the same type as the index, or None if no tensor was found. """ - def maybe_to_torch(x): - return torch.from_numpy(x) if isinstance( - attr.index, torch.Tensor) and isinstance(x, np.ndarray) else x + def to_type(tensor): + if isinstance(attr.index, torch.Tensor): + return torch.from_numpy(tensor) if isinstance( + tensor, np.ndarray) else tensor + if isinstance(attr.index, np.ndarray): + return tensor.numpy() if isinstance(tensor, + torch.Tensor) else tensor + raise ValueError attr = TensorAttr.cast(attr) assert attr.index is not None - return maybe_to_torch(self._get_tensor(attr)) + return to_type(self._get_tensor(attr)) @abstractmethod def _remove_tensor(self, attr: TensorAttr) -> bool: From a0486b983b8b6d8216a40a45bb36a47abfd6c012 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Tue, 26 Apr 2022 00:47:52 +0000 Subject: [PATCH 04/18] fix --- torch_geometric/data/feature_store.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index 732d79353a77..e5d8a6032036 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -135,6 +135,8 @@ def get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: index, or None if no tensor was found. """ def to_type(tensor): + if tensor is None: + return None if isinstance(attr.index, torch.Tensor): return torch.from_numpy(tensor) if isinstance( tensor, np.ndarray) else tensor From e6023251391b165aebbdcceff0eadc1736d0bb06 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Wed, 27 Apr 2022 02:16:23 +0000 Subject: [PATCH 05/18] API changes, WIP --- test/data/test_feature_store.py | 80 +++++++----- torch_geometric/data/feature_store.py | 172 ++++++++++++++++++-------- torch_geometric/loader/utils.py | 2 + 3 files changed, 174 insertions(+), 80 deletions(-) diff --git a/test/data/test_feature_store.py b/test/data/test_feature_store.py index 6bbe065a4c3c..a6cebead41b5 100644 --- a/test/data/test_feature_store.py +++ b/test/data/test_feature_store.py @@ -2,7 +2,12 @@ import torch -from torch_geometric.data.feature_store import FeatureStore, TensorAttr +from torch_geometric.data.feature_store import ( + AttrView, + FeatureStore, + TensorAttr, + _field_status, +) from torch_geometric.typing import FeatureTensorType @@ -15,19 +20,21 @@ def __init__(self): @classmethod def key(cls, attr: TensorAttr): - return (attr.tensor_type or '', attr.node_type or '', attr.graph_type - or '') + return (attr.group_name or '', attr.attr_name or '') def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: + index = attr.index + if index is None or index is _field_status.UNSET: + index = torch.range(0, tensor.shape[0] - 1) self.store[MyFeatureStore.key(attr)] = torch.cat( - (attr.index.reshape(-1, 1), tensor), dim=1) + (index.reshape(-1, 1), tensor), dim=1) return True def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: tensor = self.store.get(MyFeatureStore.key(attr), None) - if tensor is None: + if tensor is None or tensor is _field_status.UNSET: return None - if attr.index is not None: + if attr.index is not None and attr.index is not _field_status.UNSET: indices = torch.cat([(tensor[:, 0] == v).nonzero() for v in attr.index]).reshape(1, -1)[0] @@ -35,8 +42,6 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: return tensor[:, 1:] def _remove_tensor(self, attr: TensorAttr) -> bool: - if attr.index is not None: - raise NotImplementedError del self.store[MyFeatureStore.key(attr)] def __len__(self): @@ -50,37 +55,52 @@ def test_feature_store(): tensor = torch.Tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) index = torch.Tensor([0, 1, 2]) - tensor_type = 'feat' - node_type = 'A' - graph_type = 'graph' + attr_name = 'feat' + group_name = 'A' - attr = TensorAttr(index, tensor_type, node_type, graph_type) + attr = TensorAttr(group_name, attr_name, index) # Normal API store.put_tensor(tensor, attr) assert torch.equal(store.get_tensor(attr), tensor) assert torch.equal( - store.get_tensor( - (torch.Tensor([0, 2]), tensor_type, node_type, graph_type)), + store.get_tensor((group_name, attr_name, torch.Tensor([0, 2]))), tensor[[0, 2]], ) - assert store.get_tensor((index)) is None - store.remove_tensor((None, tensor_type, node_type, graph_type)) + assert store.get_tensor(TensorAttr(index=index)) is None + store.remove_tensor(TensorAttr(group_name, attr_name)) assert store.get_tensor(attr) is None + # Views + view = store.view(TensorAttr(group_name=group_name)) + view.attr_name = attr_name + view.index = index + assert view == AttrView(store, TensorAttr(group_name, attr_name, index)) + # Indexing - store[attr] = tensor - assert torch.equal(store[attr], tensor) - assert torch.equal( - store[(torch.Tensor([0, 2]), tensor_type, node_type, graph_type)], - tensor[[0, 2]], - ) - assert store[(index)] is None - del store[(None, tensor_type, node_type, graph_type)] - assert store.get_tensor(attr) is None - # Advanced indexing - store[attr] = tensor - assert (torch.equal( - store[TensorAttr(node_type=node_type, - graph_type=graph_type)].feat[index], tensor)) + # Setting via indexing + store[group_name, attr_name, index] = tensor + + # Fully-specified forms, all of which produce a tensor output + assert torch.equal(store[group_name, attr_name, index], tensor) + assert torch.equal(store[group_name, attr_name, None], tensor) + assert torch.equal(store[group_name, attr_name, :], tensor) + assert torch.equal(store[group_name].feat[:], tensor) + + # Partially-specified forms, which produce an AttrView object + assert store[group_name] == store.view(TensorAttr(group_name=group_name)) + assert store[group_name].feat == store.view( + TensorAttr(group_name=group_name, attr_name=attr_name)) + + # Partially-specified forms, when called, produce a Tensor output + # from the `TensorAttr` that has been partially specified. + store[group_name] = tensor + assert isinstance(store[group_name], AttrView) + assert torch.equal(store[group_name](), tensor) + + # Deletion + del store[group_name, attr_name, index] + assert store[group_name, attr_name, index] is None + del store[group_name] + assert store[group_name]() is None diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index e5d8a6032036..8d4e779514d4 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -4,11 +4,13 @@ so that varying implementations can allow for independent scale-out. This particular feature store abstraction makes a few key assumptions: - * The features we care about storing are all associated with some sort of - `index`; explicitly for PyG the the index of the node in the graph (or - the heterogeneous component of the graph it resides in). - * A feature can uniquely be identified from (a) its index and (b) any other - associated attributes specified in :obj:`TensorAttr`. + * The features we care about storing are graph node and edge features. To + this end, the attributes that the feature store supports include a + group_name (e.g. a heterogeneous node name, a heterogeneous edge type, + etc.), an attr_name (which defines the name of the feature tensor, + e.g. `feat`, `discrete_feat`, etc.), and an index. + * A feature can uniquely be identified from any associated attributes + specified in :obj:`TensorAttr`. It is the job of a feature store implementor class to handle these assumptions properly. For example, a simple in-memory feature store implementation may @@ -22,7 +24,8 @@ from abc import abstractmethod from collections.abc import MutableMapping from dataclasses import dataclass -from typing import Any, Optional +from enum import Enum +from typing import Any, Optional, Union import numpy as np import torch @@ -30,55 +33,106 @@ from torch_geometric.typing import FeatureTensorType from torch_geometric.utils.mixin import CastMixin +_field_status = Enum("FieldStatus", "UNSET") -@dataclass -class TensorAttr(CastMixin): - r"""Defines the attributes of a :obj:`FeatureStore` tensor.""" +IndexType = Union[FeatureTensorType, slice] - # The node indices the rows of the tensor correspond to - index: Optional[FeatureTensorType] = None - # The type of the feature tensor (may be used if there are multiple - # different feature tensors for the same node index) - tensor_type: Optional[str] = None +@dataclass +class TensorAttr(CastMixin): + r"""Defines the attributes of a :obj:`FeatureStore` tensor; in particular, + all the parameters necessary to uniquely identify a tensor from the feature + store. + + Note that the order of the attributes is important; this is the order in + which attributes must be provided for indexing calls. Feature store + implementor classes can define a different ordering by overriding + TensorAttr.__init__. + """ # The type of the nodes that the tensor corresponds to (may be used for # hetereogeneous graphs) - node_type: Optional[str] = None - - # The type of the graph that the nodes correspond to (may be used if a - # feature store supports multiple graphs) - graph_type: Optional[str] = None - + group_name: Optional[str] = _field_status.UNSET -class AttrView: - r"""A view of a :obj:`FeatureStore` that is obtained from an incomplete - specification of attributes. This view stores a reference to the - originating feature store as well as a :obj:`TensorAttr` object that - represents the view's (incomplete) state. + # The name of the feature tensor (may be used if there are multiple + # different feature tensors for the same node index) + attr_name: Optional[str] = _field_status.UNSET - As a result, store[TensorAttr(...)].tensor_type[idx] allows for indexing - into the store. + # The node indices the rows of the tensor correspond to + index: Optional[IndexType] = _field_status.UNSET + + def is_fully_specified(self): + r"""Whether the :obj:`TensorAttr` has no UNSET fields.""" + return all([ + getattr(self, field) != _field_status.UNSET + for field in self.__dataclass_fields__ + ]) + + def update(self, attr: 'TensorAttr'): + r"""Updates an :obj:`TensorAttr` with attributes from another + :obj:`TensorAttr`.""" + for field in self.__dataclass_fields__: + val = getattr(attr, field) + if val != _field_status.UNSET: + setattr(self, field, val) + + +class AttrView(CastMixin): + r"""Defines a view of a :obj:`FeatureStore` that is obtained from a + specification of attributes on the feature store. The view stores a + reference to the backing feature store as well as a :obj:`TensorAttr` + object that represents the view's state. + + Users can create views either using the :obj:`AttrView` constructor, + :obj:`FeatureStore.view`, or by incompletely indexing a feature store. """ _store: 'FeatureStore' - attr: TensorAttr + _attr: TensorAttr - def __init__(self, store, attr): + def __init__(self, store: 'FeatureStore', attr: TensorAttr): self._store = store - self.attr = attr + self._attr = attr + + def __getattr__(self, key): + r"""Sets the attr_name field of the backing :obj:`TensorAttr` object to + the attribute. In particular, this allows for :obj:`AttrView` to be + indexed by different values of attr_name.""" + if key in ['_attr', '_store']: + return super(AttrView, self).__getattribute__(key) + + self._attr.attr_name = key + if self._attr.is_fully_specified(): + return self._store.get_tensor(self._attr) + return self + + def __setattr__(self, key, value): + r"""Supports attribute assignment to the backing :obj:`TensorAttr` of + an :obj:`AttrView`.""" + if key in ['_attr', '_store']: + return super(AttrView, self).__setattr__(key, value) - def __getattr__(self, tensor_type): - r"""Supports attr_view.attr""" - self.attr.tensor_type = tensor_type + TensorAttr.__setattr__(self._attr, key, value) return self - def __getitem__(self, index: FeatureTensorType): - r"""Supports attr_view.attr[idx]""" - self.attr.index = index - return self._store.get_tensor(self.attr) + def __getitem__(self, index: IndexType): + r"""Supports indexing the backing :obj:`TensorAttr` object by an + index or a slice.""" + self._attr.index = index + if self._attr.is_fully_specified(): + return self._store.get_tensor(self._attr) + + def __call__(self) -> FeatureTensorType: + r"""Supports :obj:`AttrView` as a callable to force retrieval""" + return self._store.get_tensor(self._attr) + + def __eq__(self, __o: object) -> bool: + if not isinstance(__o, AttrView): + return False + + return id(self._store) == id(__o._store) and self._attr == __o._attr def __repr__(self) -> str: - return f'AttrView(store={self._store}, attr={self.attr})' + return f"AttrView(store={self._store}, attr={self._attr})" class FeatureStore(MutableMapping): @@ -109,8 +163,6 @@ def put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: bool: whether insertion was successful. """ attr = TensorAttr.cast(attr) - assert attr.index is not None - assert attr.index.size(dim=0) == tensor.size(dim=-1) return self._put_tensor(tensor, attr) @abstractmethod @@ -143,10 +195,13 @@ def to_type(tensor): if isinstance(attr.index, np.ndarray): return tensor.numpy() if isinstance(tensor, torch.Tensor) else tensor - raise ValueError + return tensor attr = TensorAttr.cast(attr) - assert attr.index is not None + if isinstance(attr.index, + slice) and (attr.index.start, attr.index.stop, + attr.index.step) == (None, None, None): + attr.index = None return to_type(self._get_tensor(attr)) @@ -194,6 +249,13 @@ def update_tensor(self, tensor: FeatureTensorType, self.remove_tensor(attr) return self.put_tensor(tensor, attr) + # :obj:`AttrView` methods ################################################# + + def view(self, attr: Optional[TensorAttr]) -> AttrView: + r"""Returns an :obj:`AttrView` of the feature store, with the defined + attributes set.""" + return AttrView(self, TensorAttr.cast(attr)) + # Python built-ins ######################################################## def __repr__(self) -> str: @@ -202,17 +264,27 @@ def __repr__(self) -> str: def __setitem__(self, key: TensorAttr, value: FeatureTensorType): r"""Supports store[tensor_attr] = tensor.""" key = TensorAttr.cast(key) - assert key.index is not None self.put_tensor(value, key) def __getitem__(self, key: TensorAttr): - r"""Supports store[tensor_attr]. If tensor_attr has index specified, - will obtain the corresponding features from the store. Otherwise, will - return an :obj:`AttrView` which can be indexed independently.""" - key = TensorAttr.cast(key) - if key.index is not None: - return self.get_tensor(key) - return AttrView(self, key) + r"""Supports pythonic indexing into the feature store. In particular, + the following rules are followed for indexing: + + * Fully-specified indexes will produce a Tensor output. A + fully-specified index specifies all the required attributes in + :obj:`TensorAttr`. + + * Partially-specified indexes will produce an AttrView output, which + is a view on the FeatureStore. If a view is called, it will produce + a Tensor output from the corresponding (partially specified) + attributes. + """ + # CastMixin will handle the case of key being a tuple or TensorAttr + # object. + attr = TensorAttr.cast(key) + if attr.is_fully_specified(): + return self.get_tensor(attr) + return AttrView(self, attr) def __delitem__(self, key: TensorAttr): r"""Supports del store[tensor_attr].""" diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py index a7b783727477..5dc488791c23 100644 --- a/torch_geometric/loader/utils.py +++ b/torch_geometric/loader/utils.py @@ -42,12 +42,14 @@ def to_csc( # `perm` can be of type `None`. perm: Optional[Tensor] = None + print("HERE") if hasattr(data, 'adj_t'): colptr, row, _ = data.adj_t.csr() elif hasattr(data, 'edge_index'): (row, col) = data.edge_index size = data.size() + print('edge index is ', data.edge_index.size(), ' and size is ', size) perm = (col * size[0]).add_(row).argsort() colptr = torch.ops.torch_sparse.ind2ptr(col[perm], size[1]) row = row[perm] From e4f4f420563bdd0c2bdd27893fe8e4c27780fef7 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Wed, 27 Apr 2022 02:24:42 +0000 Subject: [PATCH 06/18] Fix --- torch_geometric/loader/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py index 5dc488791c23..a7b783727477 100644 --- a/torch_geometric/loader/utils.py +++ b/torch_geometric/loader/utils.py @@ -42,14 +42,12 @@ def to_csc( # `perm` can be of type `None`. perm: Optional[Tensor] = None - print("HERE") if hasattr(data, 'adj_t'): colptr, row, _ = data.adj_t.csr() elif hasattr(data, 'edge_index'): (row, col) = data.edge_index size = data.size() - print('edge index is ', data.edge_index.size(), ' and size is ', size) perm = (col * size[0]).add_(row).argsort() colptr = torch.ops.torch_sparse.ind2ptr(col[perm], size[1]) row = row[perm] From 99b625a943c31617e7b1f66fa5f378ab36e5dc77 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Thu, 28 Apr 2022 01:05:56 +0000 Subject: [PATCH 07/18] Updates --- test/data/test_feature_store.py | 2 +- torch_geometric/data/feature_store.py | 88 ++++++++++++++++++++++----- 2 files changed, 75 insertions(+), 15 deletions(-) diff --git a/test/data/test_feature_store.py b/test/data/test_feature_store.py index a6cebead41b5..c69753c6af36 100644 --- a/test/data/test_feature_store.py +++ b/test/data/test_feature_store.py @@ -32,7 +32,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: tensor = self.store.get(MyFeatureStore.key(attr), None) - if tensor is None or tensor is _field_status.UNSET: + if tensor is None: return None if attr.index is not None and attr.index is not _field_status.UNSET: indices = torch.cat([(tensor[:, 0] == v).nonzero() diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index 8d4e779514d4..59ef3b014453 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -9,7 +9,7 @@ group_name (e.g. a heterogeneous node name, a heterogeneous edge type, etc.), an attr_name (which defines the name of the feature tensor, e.g. `feat`, `discrete_feat`, etc.), and an index. - * A feature can uniquely be identified from any associated attributes + * A feature can be uniquely identified from any associated attributes specified in :obj:`TensorAttr`. It is the job of a feature store implementor class to handle these assumptions @@ -61,15 +61,22 @@ class TensorAttr(CastMixin): # The node indices the rows of the tensor correspond to index: Optional[IndexType] = _field_status.UNSET + # Convenience methods ##################################################### + + def is_set(self, attr): + r"""Whether an attribute is set in :obj:`TensorAttr`.""" + assert attr in self.__dataclass_fields__ + return getattr(self, attr) != _field_status.UNSET + def is_fully_specified(self): - r"""Whether the :obj:`TensorAttr` has no UNSET fields.""" + r"""Whether the :obj:`TensorAttr` has no unset fields.""" return all([ getattr(self, field) != _field_status.UNSET for field in self.__dataclass_fields__ ]) def update(self, attr: 'TensorAttr'): - r"""Updates an :obj:`TensorAttr` with attributes from another + r"""Updates an :obj:`TensorAttr` with set attributes from another :obj:`TensorAttr`.""" for field in self.__dataclass_fields__: val = getattr(attr, field) @@ -93,10 +100,29 @@ def __init__(self, store: 'FeatureStore', attr: TensorAttr): self._store = store self._attr = attr - def __getattr__(self, key): + # Properties ############################################################## + + @property + def attr(self) -> TensorAttr: + return self._attr + + @property + def store(self) -> 'FeatureStore': + return self._store + + # Python built-ins ######################################################## + + def __getattr__(self, key) -> 'AttrView': r"""Sets the attr_name field of the backing :obj:`TensorAttr` object to - the attribute. In particular, this allows for :obj:`AttrView` to be - indexed by different values of attr_name.""" + the attribute. This allows for :obj:`AttrView` to be indexed by + different values of attr_name. In particular, for a feature store that + has `feat` as an `attr_name`, the following code indexes into `feat`: + + .. code-block:: python + + store[group_name].feat[:] + + """ if key in ['_attr', '_store']: return super(AttrView, self).__getattribute__(key) @@ -107,25 +133,59 @@ def __getattr__(self, key): def __setattr__(self, key, value): r"""Supports attribute assignment to the backing :obj:`TensorAttr` of - an :obj:`AttrView`.""" + an :obj:`AttrView`. This allows for :obj:`AttrView` objects to set + their backing attribute values. In particular, the following operation + sets the `index` of an :obj:`AttrView`: + + .. code-block:: python + + view = store.view(TensorAttr(group_name)) + view.index = torch.Tensor([1, 2, 3]) + + """ if key in ['_attr', '_store']: return super(AttrView, self).__setattr__(key, value) TensorAttr.__setattr__(self._attr, key, value) - return self - def __getitem__(self, index: IndexType): + def __getitem__( + self, + index: IndexType, + ) -> Union['AttrView', FeatureTensorType]: r"""Supports indexing the backing :obj:`TensorAttr` object by an - index or a slice.""" + index or a slice. If the index operation results in a fully-specified + :obj:`AttrView`, a Tensor is returned. Otherwise, the :obj:`AttrView` + object is returned. The following operation returns a Tensor object + as a result of the index specification: + + .. code-block:: python + + store[group_name, attr_name][:] + + """ self._attr.index = index if self._attr.is_fully_specified(): return self._store.get_tensor(self._attr) + return self def __call__(self) -> FeatureTensorType: - r"""Supports :obj:`AttrView` as a callable to force retrieval""" + r"""Supports :obj:`AttrView` as a callable to force retrieval from + the currently specified attributes. In particular, this passes the + current :obj:`TensorAttr` object to a GET call, regardless of whether + all attributes have been specified. It returns the result of this + call. In particular, the following operation returns a Tensor by + performing a GET operation on the backing feature store: + + .. code-block:: python + + store[group_name, attr_name]() + + """ return self._store.get_tensor(self._attr) def __eq__(self, __o: object) -> bool: + r"""Compares two :obj:`AttrView` objects by checking equality of their + :obj:`FeatureStore` references and :obj:`TensorAttr` attributes.""" if not isinstance(__o, AttrView): return False @@ -258,9 +318,6 @@ def view(self, attr: Optional[TensorAttr]) -> AttrView: # Python built-ins ######################################################## - def __repr__(self) -> str: - return f'{self.__class__.__name__}(backend={self.backend})' - def __setitem__(self, key: TensorAttr, value: FeatureTensorType): r"""Supports store[tensor_attr] = tensor.""" key = TensorAttr.cast(key) @@ -297,3 +354,6 @@ def __iter__(self): @abstractmethod def __len__(self): pass + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(backend={self.backend})' From 832802a1ddd11b5bbc1f3a9ad3c0e063cefc915a Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Thu, 28 Apr 2022 01:26:27 +0000 Subject: [PATCH 08/18] More cleanup for new API --- test/data/test_feature_store.py | 24 ++++++++++++++++-------- torch_geometric/data/feature_store.py | 24 ++++++++++++++---------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/test/data/test_feature_store.py b/test/data/test_feature_store.py index c69753c6af36..2bceb6bd7364 100644 --- a/test/data/test_feature_store.py +++ b/test/data/test_feature_store.py @@ -6,7 +6,6 @@ AttrView, FeatureStore, TensorAttr, - _field_status, ) from torch_geometric.typing import FeatureTensorType @@ -20,26 +19,36 @@ def __init__(self): @classmethod def key(cls, attr: TensorAttr): + r"""Define the key as (group_name, attr_name).""" return (attr.group_name or '', attr.attr_name or '') def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: index = attr.index - if index is None or index is _field_status.UNSET: + + # Not set or None indices define the obvious index + if not attr.is_set('index') or index is None: index = torch.range(0, tensor.shape[0] - 1) + + # Store the index as a column self.store[MyFeatureStore.key(attr)] = torch.cat( (index.reshape(-1, 1), tensor), dim=1) + return True def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: tensor = self.store.get(MyFeatureStore.key(attr), None) if tensor is None: return None - if attr.index is not None and attr.index is not _field_status.UNSET: - indices = torch.cat([(tensor[:, 0] == v).nonzero() - for v in attr.index]).reshape(1, -1)[0] - return torch.index_select(tensor[:, 1:], 0, indices) - return tensor[:, 1:] + # Not set or None indices return the whole tensor + if not attr.is_set('index') or attr.index is None: + return tensor[:, 1:] + + # Index into the tensor + indices = torch.cat([(tensor[:, 0] == v).nonzero() + for v in attr.index]).reshape(1, -1)[0] + + return torch.index_select(tensor[:, 1:], 0, indices) def _remove_tensor(self, attr: TensorAttr) -> bool: del self.store[MyFeatureStore.key(attr)] @@ -50,7 +59,6 @@ def __len__(self): def test_feature_store(): r"""Tests basic API and indexing functionality of a feature store.""" - store = MyFeatureStore() tensor = torch.Tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) index = torch.Tensor([0, 1, 2]) diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index 59ef3b014453..d9e9ec2e3719 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -196,9 +196,13 @@ def __repr__(self) -> str: class FeatureStore(MutableMapping): - def __init__(self, backend: Any): - r"""Initializes the feature store with a specified backend.""" + def __init__(self, backend: Any, attr_cls: Any = TensorAttr): + r"""Initializes the feature store with a specified backend. Implementor + classes can customize the ordering and require nature of their + :obj:`TensorAttr` tensor attributes by subclassing :obj:`TensorAttr` + and passing the subclass as `attr_cls`.""" self.backend = backend + self._attr_cls = attr_cls # Core (CRUD) ############################################################# @@ -222,7 +226,7 @@ def put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: Returns: bool: whether insertion was successful. """ - attr = TensorAttr.cast(attr) + attr = self._attr_cls.cast(attr) return self._put_tensor(tensor, attr) @abstractmethod @@ -257,7 +261,7 @@ def to_type(tensor): torch.Tensor) else tensor return tensor - attr = TensorAttr.cast(attr) + attr = self._attr_cls.cast(attr) if isinstance(attr.index, slice) and (attr.index.start, attr.index.stop, attr.index.step) == (None, None, None): @@ -284,7 +288,7 @@ def remove_tensor(self, attr: TensorAttr) -> bool: Returns: bool: whether deletion was succesful. """ - attr = TensorAttr.cast(attr) + attr = self._attr_cls.cast(attr) self._remove_tensor(attr) def update_tensor(self, tensor: FeatureTensorType, @@ -305,7 +309,7 @@ def update_tensor(self, tensor: FeatureTensorType, Returns: bool: whether the update was succesful. """ - attr = TensorAttr.cast(attr) + attr = self._attr_cls.cast(attr) self.remove_tensor(attr) return self.put_tensor(tensor, attr) @@ -314,13 +318,13 @@ def update_tensor(self, tensor: FeatureTensorType, def view(self, attr: Optional[TensorAttr]) -> AttrView: r"""Returns an :obj:`AttrView` of the feature store, with the defined attributes set.""" - return AttrView(self, TensorAttr.cast(attr)) + return AttrView(self, self._attr_cls.cast(attr)) # Python built-ins ######################################################## def __setitem__(self, key: TensorAttr, value: FeatureTensorType): r"""Supports store[tensor_attr] = tensor.""" - key = TensorAttr.cast(key) + key = self._attr_cls.cast(key) self.put_tensor(value, key) def __getitem__(self, key: TensorAttr): @@ -338,14 +342,14 @@ def __getitem__(self, key: TensorAttr): """ # CastMixin will handle the case of key being a tuple or TensorAttr # object. - attr = TensorAttr.cast(key) + attr = self._attr_cls.cast(key) if attr.is_fully_specified(): return self.get_tensor(attr) return AttrView(self, attr) def __delitem__(self, key: TensorAttr): r"""Supports del store[tensor_attr].""" - key = TensorAttr.cast(key) + key = self._attr_cls.cast(key) self.remove_tensor(key) def __iter__(self): From 65e5e0b23fe00058dfa78a5e7645592cd3d82feb Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Thu, 28 Apr 2022 01:50:39 +0000 Subject: [PATCH 09/18] Add override example --- test/data/test_feature_store.py | 47 ++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/test/data/test_feature_store.py b/test/data/test_feature_store.py index 2bceb6bd7364..e909841d2647 100644 --- a/test/data/test_feature_store.py +++ b/test/data/test_feature_store.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Optional import torch @@ -6,6 +7,7 @@ AttrView, FeatureStore, TensorAttr, + _field_status, ) from torch_geometric.typing import FeatureTensorType @@ -27,7 +29,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: # Not set or None indices define the obvious index if not attr.is_set('index') or index is None: - index = torch.range(0, tensor.shape[0] - 1) + index = torch.arange(0, tensor.shape[0]) # Store the index as a column self.store[MyFeatureStore.key(attr)] = torch.cat( @@ -57,6 +59,31 @@ def __len__(self): raise NotImplementedError +@dataclass +class MyTensorAttrNoGroupName(TensorAttr): + def __init__(self, attr_name=_field_status.UNSET, + index=_field_status.UNSET): + # Treat group_name as optional, and move it to the end + super().__init__(None, attr_name, index) + + +@dataclass +class MyFeatureStoreNoGroupName(MyFeatureStore): + # pylint: disable=super-init-not-called + def __init__(self): + FeatureStore.__init__(self, backend='test', + attr_cls=MyTensorAttrNoGroupName) + self.store = {} + + @classmethod + def key(cls, attr: TensorAttr): + r"""Define the key as (group_name, attr_name).""" + return attr.attr_name or '' + + def __len__(self): + raise NotImplementedError + + def test_feature_store(): r"""Tests basic API and indexing functionality of a feature store.""" store = MyFeatureStore() @@ -112,3 +139,21 @@ def test_feature_store(): assert store[group_name, attr_name, index] is None del store[group_name] assert store[group_name]() is None + + +def test_feature_store_override(): + store = MyFeatureStoreNoGroupName() + tensor = torch.Tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) + index = torch.Tensor([0, 1, 2]) + + attr_name = 'feat' + + # Only use attr_name and index, in that order + store[attr_name, index] = tensor + + # A few assertions to ensure group_name is not needed + assert isinstance(store[attr_name], AttrView) + assert torch.equal(store[attr_name, index], tensor) + assert torch.equal(store[attr_name][index], tensor) + assert torch.equal(store[attr_name][:], tensor) + assert torch.equal(store[attr_name, :], tensor) From 6fdf06267f835f9b925eb87ed308680de72c6dc9 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Thu, 28 Apr 2022 01:53:39 +0000 Subject: [PATCH 10/18] Fix --- test/data/test_feature_store.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/data/test_feature_store.py b/test/data/test_feature_store.py index e909841d2647..7d851065fae0 100644 --- a/test/data/test_feature_store.py +++ b/test/data/test_feature_store.py @@ -67,7 +67,6 @@ def __init__(self, attr_name=_field_status.UNSET, super().__init__(None, attr_name, index) -@dataclass class MyFeatureStoreNoGroupName(MyFeatureStore): # pylint: disable=super-init-not-called def __init__(self): From 94369fe4283ec46fc3e9f6ca1737ba9ea6e81a55 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Thu, 28 Apr 2022 04:18:25 +0000 Subject: [PATCH 11/18] Remove unnecessary properties --- torch_geometric/data/feature_store.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index d9e9ec2e3719..b9935c69129f 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -100,16 +100,6 @@ def __init__(self, store: 'FeatureStore', attr: TensorAttr): self._store = store self._attr = attr - # Properties ############################################################## - - @property - def attr(self) -> TensorAttr: - return self._attr - - @property - def store(self) -> 'FeatureStore': - return self._store - # Python built-ins ######################################################## def __getattr__(self, key) -> 'AttrView': @@ -192,7 +182,8 @@ def __eq__(self, __o: object) -> bool: return id(self._store) == id(__o._store) and self._attr == __o._attr def __repr__(self) -> str: - return f"AttrView(store={self._store}, attr={self._attr})" + return (f'{self.__class__.__name__}(store={self._store}, ' + f'attr={self._attr})') class FeatureStore(MutableMapping): From e0c63c975972ef03bc771e0994185fb53a4d8d8d Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 28 Apr 2022 14:03:14 +0200 Subject: [PATCH 12/18] remove backend --- test/data/test_feature_store.py | 12 +++--------- torch_geometric/data/feature_store.py | 14 +++++++------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/test/data/test_feature_store.py b/test/data/test_feature_store.py index 7d851065fae0..bbc2de89eb74 100644 --- a/test/data/test_feature_store.py +++ b/test/data/test_feature_store.py @@ -13,15 +13,12 @@ class MyFeatureStore(FeatureStore): - r"""A basic feature store, does NOT implement all functionality of a - fully-fledged feature store. Only works for Torch tensors.""" def __init__(self): - super().__init__(backend='test') + super().__init__() self.store = {} @classmethod def key(cls, attr: TensorAttr): - r"""Define the key as (group_name, attr_name).""" return (attr.group_name or '', attr.attr_name or '') def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: @@ -68,15 +65,12 @@ def __init__(self, attr_name=_field_status.UNSET, class MyFeatureStoreNoGroupName(MyFeatureStore): - # pylint: disable=super-init-not-called def __init__(self): - FeatureStore.__init__(self, backend='test', - attr_cls=MyTensorAttrNoGroupName) - self.store = {} + super().__init__() + self._attr_cls = MyTensorAttrNoGroupName @classmethod def key(cls, attr: TensorAttr): - r"""Define the key as (group_name, attr_name).""" return attr.attr_name or '' def __len__(self): diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index b9935c69129f..7506f769c34d 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -187,12 +187,12 @@ def __repr__(self) -> str: class FeatureStore(MutableMapping): - def __init__(self, backend: Any, attr_cls: Any = TensorAttr): - r"""Initializes the feature store with a specified backend. Implementor - classes can customize the ordering and require nature of their - :obj:`TensorAttr` tensor attributes by subclassing :obj:`TensorAttr` - and passing the subclass as `attr_cls`.""" - self.backend = backend + def __init__(self, attr_cls: Any = TensorAttr): + r"""Initializes the feature store. Implementor classes can customize + the ordering and require nature of their :obj:`TensorAttr` tensor + attributes by subclassing :class:`TensorAttr` and passing the subclass + as `attr_cls`.""" + super().__init__() self._attr_cls = attr_cls # Core (CRUD) ############################################################# @@ -351,4 +351,4 @@ def __len__(self): pass def __repr__(self) -> str: - return f'{self.__class__.__name__}(backend={self.backend})' + return f'{self.__class__.__name__}()' From 0fd6dd2b1cba29f059cc74e424bd572ae93f1265 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Thu, 28 Apr 2022 23:08:20 +0000 Subject: [PATCH 13/18] Updates --- torch_geometric/data/feature_store.py | 124 +++++++++++++++----------- 1 file changed, 72 insertions(+), 52 deletions(-) diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index b9935c69129f..cd0a465e2fe6 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -1,16 +1,16 @@ r""" -This class defines the abstraction for a Graph feature store. The goal of a -feature store is to abstract away all node and edge feature memory management -so that varying implementations can allow for independent scale-out. +This class defines the abstraction for a backend-agnostic feature store. The +goal of a feature store is to abstract away all node and edge feature memory +management so that varying implementations can allow for independent scale-out. This particular feature store abstraction makes a few key assumptions: - * The features we care about storing are graph node and edge features. To - this end, the attributes that the feature store supports include a + * The features we care about storing are node and edge features of a graph. + To this end, the attributes that the feature store supports include a group_name (e.g. a heterogeneous node name, a heterogeneous edge type, etc.), an attr_name (which defines the name of the feature tensor, - e.g. `feat`, `discrete_feat`, etc.), and an index. + e.g. `x`, `edge_attr`, etc.), and an index. * A feature can be uniquely identified from any associated attributes - specified in :obj:`TensorAttr`. + specified in :class:`TensorAttr`. It is the job of a feature store implementor class to handle these assumptions properly. For example, a simple in-memory feature store implementation may @@ -21,6 +21,7 @@ Major TODOs for future implementation: * Async `put` and `get` functionality """ +import copy from abc import abstractmethod from collections.abc import MutableMapping from dataclasses import dataclass @@ -35,7 +36,9 @@ _field_status = Enum("FieldStatus", "UNSET") -IndexType = Union[FeatureTensorType, slice] +# We allow indexing with a tensor, numpy array, Python slicing, or a single +# integer index. +IndexType = Union[torch.Tensor, np.ndarray, slice, int] @dataclass @@ -50,15 +53,13 @@ class TensorAttr(CastMixin): TensorAttr.__init__. """ - # The type of the nodes that the tensor corresponds to (may be used for - # hetereogeneous graphs) - group_name: Optional[str] = _field_status.UNSET + # The group name that the tensor corresponds to. Defaults to None. + group_name: Optional[str] = None - # The name of the feature tensor (may be used if there are multiple - # different feature tensors for the same node index) - attr_name: Optional[str] = _field_status.UNSET + # The name of the tensor within its group. Defaults to None. + attr_name: Optional[str] = None - # The node indices the rows of the tensor correspond to + # The node indices the rows of the tensor correspond to. Defaults to UNSET. index: Optional[IndexType] = _field_status.UNSET # Convenience methods ##################################################### @@ -85,24 +86,38 @@ def update(self, attr: 'TensorAttr'): class AttrView(CastMixin): - r"""Defines a view of a :obj:`FeatureStore` that is obtained from a + r"""Defines a view of a :class:`FeatureStore` that is obtained from a specification of attributes on the feature store. The view stores a - reference to the backing feature store as well as a :obj:`TensorAttr` + reference to the backing feature store as well as a :class:`TensorAttr` object that represents the view's state. Users can create views either using the :obj:`AttrView` constructor, - :obj:`FeatureStore.view`, or by incompletely indexing a feature store. - """ - _store: 'FeatureStore' - _attr: TensorAttr + :obj:`FeatureStore.view`, or by incompletely indexing a feature store. For + example, the following calls all create views: + + .. code-block:: python + + store[group_name] + store[group_name].feat + store[group_name, feat] + + While the following calls all materialize those views and produce tensors + by either calling the view or fully-specifying the view: + + .. code-block:: python + + store[group_name]() + store[group_name].feat[index] + store[group_name, feat][index] + """ def __init__(self, store: 'FeatureStore', attr: TensorAttr): - self._store = store - self._attr = attr + self.__dict__['_store'] = store + self.__dict__['_attr'] = attr # Python built-ins ######################################################## - def __getattr__(self, key) -> 'AttrView': + def __getattr__(self, key: str) -> 'AttrView': r"""Sets the attr_name field of the backing :obj:`TensorAttr` object to the attribute. This allows for :obj:`AttrView` to be indexed by different values of attr_name. In particular, for a feature store that @@ -113,13 +128,11 @@ def __getattr__(self, key) -> 'AttrView': store[group_name].feat[:] """ - if key in ['_attr', '_store']: - return super(AttrView, self).__getattribute__(key) - - self._attr.attr_name = key - if self._attr.is_fully_specified(): - return self._store.get_tensor(self._attr) - return self + out = copy.copy(self) + out._attr.attr_name = key + if out._attr.is_fully_specified(): + return out._store.get_tensor(out._attr) + return out def __setattr__(self, key, value): r"""Supports attribute assignment to the backing :obj:`TensorAttr` of @@ -133,10 +146,7 @@ def __setattr__(self, key, value): view.index = torch.Tensor([1, 2, 3]) """ - if key in ['_attr', '_store']: - return super(AttrView, self).__setattr__(key, value) - - TensorAttr.__setattr__(self._attr, key, value) + setattr(self._attr, key, value) def __getitem__( self, @@ -153,10 +163,11 @@ def __getitem__( store[group_name, attr_name][:] """ - self._attr.index = index - if self._attr.is_fully_specified(): - return self._store.get_tensor(self._attr) - return self + out = copy.copy(self) + out._attr.index = index + if out._attr.is_fully_specified(): + return out._store.get_tensor(out._attr) + return out def __call__(self) -> FeatureTensorType: r"""Supports :obj:`AttrView` as a callable to force retrieval from @@ -173,13 +184,19 @@ def __call__(self) -> FeatureTensorType: """ return self._store.get_tensor(self._attr) + def __copy__(self): + out = self.__class__.__new__(self.__class__) + for key, value in self.__dict__.items(): + out.__dict__[key] = value + return out + def __eq__(self, __o: object) -> bool: r"""Compares two :obj:`AttrView` objects by checking equality of their :obj:`FeatureStore` references and :obj:`TensorAttr` attributes.""" if not isinstance(__o, AttrView): return False - return id(self._store) == id(__o._store) and self._attr == __o._attr + return self._store == __o._store and self._attr == __o._attr def __repr__(self) -> str: return (f'{self.__class__.__name__}(store={self._store}, ' @@ -202,7 +219,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: r"""Implemented by :obj:`FeatureStore` subclasses.""" pass - def put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: + def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool: r"""Synchronously adds a :obj:`FeatureTensorType` object to the feature store. @@ -217,7 +234,7 @@ def put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: Returns: bool: whether insertion was successful. """ - attr = self._attr_cls.cast(attr) + attr = self._attr_cls.cast(*args, **kwargs) return self._put_tensor(tensor, attr) @abstractmethod @@ -225,10 +242,10 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: r"""Implemented by :obj:`FeatureStore` subclasses.""" pass - def get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: + def get_tensor(self, *args, **kwargs) -> Optional[FeatureTensorType]: r"""Synchronously obtains a :obj:`FeatureTensorType` object from the feature store. Feature store implementors guarantee that the call - get_tensor(put_tensor(tensor, attr), attr) = tensor. + get_tensor(put_tensor(tensor, attr), attr) = tensor holds. Args: attr (TensorAttr): any relevant tensor attributes that correspond @@ -252,7 +269,7 @@ def to_type(tensor): torch.Tensor) else tensor return tensor - attr = self._attr_cls.cast(attr) + attr = self._attr_cls.cast(*args, **kwargs) if isinstance(attr.index, slice) and (attr.index.start, attr.index.stop, attr.index.step) == (None, None, None): @@ -265,7 +282,7 @@ def _remove_tensor(self, attr: TensorAttr) -> bool: r"""Implemented by :obj:`FeatureStore` subclasses.""" pass - def remove_tensor(self, attr: TensorAttr) -> bool: + def remove_tensor(self, *args, **kwargs) -> bool: r"""Removes a :obj:`FeatureTensorType` object from the feature store. Args: @@ -279,11 +296,11 @@ def remove_tensor(self, attr: TensorAttr) -> bool: Returns: bool: whether deletion was succesful. """ - attr = self._attr_cls.cast(attr) + attr = self._attr_cls.cast(*args, **kwargs) self._remove_tensor(attr) - def update_tensor(self, tensor: FeatureTensorType, - attr: TensorAttr) -> bool: + def update_tensor(self, tensor: FeatureTensorType, *args, + **kwargs) -> bool: r"""Updates a :obj:`FeatureTensorType` object with a new value. implementor classes can choose to define more efficient update methods; the default performs a removal and insertion. @@ -300,16 +317,16 @@ def update_tensor(self, tensor: FeatureTensorType, Returns: bool: whether the update was succesful. """ - attr = self._attr_cls.cast(attr) + attr = self._attr_cls.cast(*args, **kwargs) self.remove_tensor(attr) return self.put_tensor(tensor, attr) # :obj:`AttrView` methods ################################################# - def view(self, attr: Optional[TensorAttr]) -> AttrView: + def view(self, *args, **kwargs) -> AttrView: r"""Returns an :obj:`AttrView` of the feature store, with the defined attributes set.""" - return AttrView(self, self._attr_cls.cast(attr)) + return AttrView(self, self._attr_cls.cast(*args, **kwargs)) # Python built-ins ######################################################## @@ -346,6 +363,9 @@ def __delitem__(self, key: TensorAttr): def __iter__(self): raise NotImplementedError + def __eq__(self, __o: object) -> bool: + return id(self) == id(__o) + @abstractmethod def __len__(self): pass From 385512e865ae18160ad2021826d526df5dcf75b3 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Fri, 29 Apr 2022 00:04:23 +0000 Subject: [PATCH 14/18] More updates --- test/data/test_feature_store.py | 19 ++-- torch_geometric/data/feature_store.py | 120 ++++++++++++++++++++------ 2 files changed, 107 insertions(+), 32 deletions(-) diff --git a/test/data/test_feature_store.py b/test/data/test_feature_store.py index bbc2de89eb74..4fc94ee94300 100644 --- a/test/data/test_feature_store.py +++ b/test/data/test_feature_store.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import Optional +import pytest import torch from torch_geometric.data.feature_store import ( @@ -25,7 +26,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: index = attr.index # Not set or None indices define the obvious index - if not attr.is_set('index') or index is None: + if index is None: index = torch.arange(0, tensor.shape[0]) # Store the index as a column @@ -40,7 +41,7 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: return None # Not set or None indices return the whole tensor - if not attr.is_set('index') or attr.index is None: + if attr.index is None: return tensor[:, 1:] # Index into the tensor @@ -92,17 +93,17 @@ def test_feature_store(): store.put_tensor(tensor, attr) assert torch.equal(store.get_tensor(attr), tensor) assert torch.equal( - store.get_tensor((group_name, attr_name, torch.Tensor([0, 2]))), + store.get_tensor(group_name, attr_name, torch.Tensor([0, 2])), tensor[[0, 2]], ) - assert store.get_tensor(TensorAttr(index=index)) is None - store.remove_tensor(TensorAttr(group_name, attr_name)) + assert store.get_tensor(None, None, index) is None + store.remove_tensor(group_name, attr_name, None) assert store.get_tensor(attr) is None # Views view = store.view(TensorAttr(group_name=group_name)) view.attr_name = attr_name - view.index = index + view['index'] = index assert view == AttrView(store, TensorAttr(group_name, attr_name, index)) # Indexing @@ -114,7 +115,13 @@ def test_feature_store(): assert torch.equal(store[group_name, attr_name, index], tensor) assert torch.equal(store[group_name, attr_name, None], tensor) assert torch.equal(store[group_name, attr_name, :], tensor) + assert torch.equal(store[group_name][attr_name][:], tensor) assert torch.equal(store[group_name].feat[:], tensor) + assert torch.equal(store.view().A.feat[:], tensor) + + with pytest.raises(AttributeError) as exc_info: + _ = store.view(group_name=group_name, index=None).feat.A + print(exc_info) # Partially-specified forms, which produce an AttrView object assert store[group_name] == store.view(TensorAttr(group_name=group_name)) diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index fca885b6788a..2a6c9916fee7 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -54,10 +54,10 @@ class TensorAttr(CastMixin): """ # The group name that the tensor corresponds to. Defaults to None. - group_name: Optional[str] = None + group_name: Optional[str] = _field_status.UNSET # The name of the tensor within its group. Defaults to None. - attr_name: Optional[str] = None + attr_name: Optional[str] = _field_status.UNSET # The node indices the rows of the tensor correspond to. Defaults to UNSET. index: Optional[IndexType] = _field_status.UNSET @@ -76,6 +76,13 @@ def is_fully_specified(self): for field in self.__dataclass_fields__ ]) + def fully_specify(self): + r"""Sets all UNSET fields to None.""" + for field in self.__dataclass_fields__: + if getattr(self, field) == _field_status.UNSET: + setattr(self, field, None) + return self + def update(self, attr: 'TensorAttr'): r"""Updates an :obj:`TensorAttr` with set attributes from another :obj:`TensorAttr`.""" @@ -115,25 +122,58 @@ def __init__(self, store: 'FeatureStore', attr: TensorAttr): self.__dict__['_store'] = store self.__dict__['_attr'] = attr - # Python built-ins ######################################################## + # Advanced indexing ####################################################### - def __getattr__(self, key: str) -> 'AttrView': - r"""Sets the attr_name field of the backing :obj:`TensorAttr` object to - the attribute. This allows for :obj:`AttrView` to be indexed by - different values of attr_name. In particular, for a feature store that - has `feat` as an `attr_name`, the following code indexes into `feat`: + def __getattr__(self, key: Any) -> 'AttrView': + r"""Sets the first unset field of the backing :obj:`TensorAttr` object + to the attribute. This allows for :obj:`AttrView` to be indexed by + different values of attributes, in order. In particular, for a feature + store that we want to index by `group_name` group and `attr_name` attr, + the following code will do so: .. code-block:: python - store[group_name].feat[:] + store[group, attr] + store[group].attr + store.group.attr """ out = copy.copy(self) - out._attr.attr_name = key + + # First attribute that is UNSET + attr_name = None + for field in out._attr.__dataclass_fields__: + if getattr(out._attr, field) == _field_status.UNSET: + attr_name = field + break + + if attr_name is None: + raise AttributeError(f"Cannot access attribute {key} on view " + f"{out} as all attributes have already been " + f"set in this view.") + + setattr(out._attr, attr_name, key) if out._attr.is_fully_specified(): return out._store.get_tensor(out._attr) return out + def __getitem__(self, key: Any) -> Union['AttrView', FeatureTensorType]: + r"""Sets the first unset field of the backing :obj:`TensorAttr` object + to the attribute via indexing. This allows for :obj:`AttrView` to be + indexed by different values of attributes, in order. In particular, for + a feature store that we want to index by `group_name` group and + `attr_name` attr, the following code will do so: + + .. code-block:: python + + store[group, attr] + store[group][attr] + + """ + return self.__getattr__(key) + + # Setting attributes ###################################################### + def __setattr__(self, key, value): r"""Supports attribute assignment to the backing :obj:`TensorAttr` of an :obj:`AttrView`. This allows for :obj:`AttrView` objects to set @@ -146,28 +186,28 @@ def __setattr__(self, key, value): view.index = torch.Tensor([1, 2, 3]) """ + if key not in self._attr.__dataclass_fields__: + raise ValueError(f"Attempted to set nonexistent attribute {key} " + f"(acceptable attributes are " + f"{self._attr.__dataclass_fields__}).") + setattr(self._attr, key, value) - def __getitem__( - self, - index: IndexType, - ) -> Union['AttrView', FeatureTensorType]: - r"""Supports indexing the backing :obj:`TensorAttr` object by an - index or a slice. If the index operation results in a fully-specified - :obj:`AttrView`, a Tensor is returned. Otherwise, the :obj:`AttrView` - object is returned. The following operation returns a Tensor object - as a result of the index specification: + def __setitem__(self, key, value): + r"""Supports attribute assignment to the backing :obj:`TensorAttr` of + an :obj:`AttrView` via indexing. This allows for :obj:`AttrView` + objects to set their backing attribute values. In particular, the + following operation sets the `index` of an :obj:`AttrView`: .. code-block:: python - store[group_name, attr_name][:] + view = store.view(TensorAttr(group_name)) + view['index'] = torch.Tensor([1, 2, 3]) """ - out = copy.copy(self) - out._attr.index = index - if out._attr.is_fully_specified(): - return out._store.get_tensor(out._attr) - return out + self.__setattr__(key, value) + + # Miscellaneous built-ins ################################################# def __call__(self) -> FeatureTensorType: r"""Supports :obj:`AttrView` as a callable to force retrieval from @@ -182,7 +222,10 @@ def __call__(self) -> FeatureTensorType: store[group_name, attr_name]() """ - return self._store.get_tensor(self._attr) + # Set all UNSET values to None if forced execution + out = copy.copy(self) + out._attr.fully_specify() + return out._store.get_tensor(out._attr) def __copy__(self): out = self.__class__.__new__(self.__class__) @@ -235,6 +278,11 @@ def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool: bool: whether insertion was successful. """ attr = self._attr_cls.cast(*args, **kwargs) + if not attr.is_fully_specified(): + raise ValueError(f"The input TensorAttr {attr} is not fully " + f"specified. Please fully specify the input " + f"by specifying all UNSET fields.") + return self._put_tensor(tensor, attr) @abstractmethod @@ -275,6 +323,11 @@ def to_type(tensor): attr.index.step) == (None, None, None): attr.index = None + if not attr.is_fully_specified(): + raise ValueError(f"The input TensorAttr {attr} is not fully " + f"specified. Please fully specify the input " + f"by specifying all UNSET fields.") + return to_type(self._get_tensor(attr)) @abstractmethod @@ -297,6 +350,11 @@ def remove_tensor(self, *args, **kwargs) -> bool: bool: whether deletion was succesful. """ attr = self._attr_cls.cast(*args, **kwargs) + if not attr.is_fully_specified(): + raise ValueError(f"The input TensorAttr {attr} is not fully " + f"specified. Please fully specify the input " + f"by specifying all UNSET fields.") + self._remove_tensor(attr) def update_tensor(self, tensor: FeatureTensorType, *args, @@ -318,6 +376,11 @@ def update_tensor(self, tensor: FeatureTensorType, *args, bool: whether the update was succesful. """ attr = self._attr_cls.cast(*args, **kwargs) + if not attr.is_fully_specified(): + raise ValueError(f"The input TensorAttr {attr} is not fully " + f"specified. Please fully specify the input " + f"by specifying all UNSET fields.") + self.remove_tensor(attr) return self.put_tensor(tensor, attr) @@ -333,6 +396,10 @@ def view(self, *args, **kwargs) -> AttrView: def __setitem__(self, key: TensorAttr, value: FeatureTensorType): r"""Supports store[tensor_attr] = tensor.""" key = self._attr_cls.cast(key) + + # We need to fully specify the key for __setitem__ as it does not make + # sense to work with a view here. + key.fully_specify() self.put_tensor(value, key) def __getitem__(self, key: TensorAttr): @@ -358,6 +425,7 @@ def __getitem__(self, key: TensorAttr): def __delitem__(self, key: TensorAttr): r"""Supports del store[tensor_attr].""" key = self._attr_cls.cast(key) + key.fully_specify() self.remove_tensor(key) def __iter__(self): From a5c4f1c48bab91060089094368fa5793e03ea069 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 29 Apr 2022 13:36:10 +0200 Subject: [PATCH 15/18] pass --- torch_geometric/data/feature_store.py | 140 ++++++++++++-------------- 1 file changed, 66 insertions(+), 74 deletions(-) diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index 2a6c9916fee7..05dc9a78e1d2 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -1,16 +1,15 @@ r""" This class defines the abstraction for a backend-agnostic feature store. The -goal of a feature store is to abstract away all node and edge feature memory +goal of the feature store is to abstract away all node and edge feature memory management so that varying implementations can allow for independent scale-out. This particular feature store abstraction makes a few key assumptions: - * The features we care about storing are node and edge features of a graph. - To this end, the attributes that the feature store supports include a - group_name (e.g. a heterogeneous node name, a heterogeneous edge type, - etc.), an attr_name (which defines the name of the feature tensor, - e.g. `x`, `edge_attr`, etc.), and an index. - * A feature can be uniquely identified from any associated attributes - specified in :class:`TensorAttr`. +* The features we care about storing are node and edge features of a graph. + To this end, the attributes that the feature store supports include a + `group_name` (e.g. a heterogeneous node name or a heterogeneous edge type), + an `attr_name` (e.g. `x` or `edge_attr`), and an index. +* A feature can be uniquely identified from any associated attributes specified + in `TensorAttr`. It is the job of a feature store implementor class to handle these assumptions properly. For example, a simple in-memory feature store implementation may @@ -43,14 +42,14 @@ @dataclass class TensorAttr(CastMixin): - r"""Defines the attributes of a :obj:`FeatureStore` tensor; in particular, + r"""Defines the attributes of a class:`FeatureStore` tensor; in particular, all the parameters necessary to uniquely identify a tensor from the feature store. Note that the order of the attributes is important; this is the order in which attributes must be provided for indexing calls. Feature store implementor classes can define a different ordering by overriding - TensorAttr.__init__. + :meth:`TensorAttr.__init__`. """ # The group name that the tensor corresponds to. Defaults to None. @@ -64,32 +63,28 @@ class TensorAttr(CastMixin): # Convenience methods ##################################################### - def is_set(self, attr): + def is_set(self, key: str) -> bool: r"""Whether an attribute is set in :obj:`TensorAttr`.""" - assert attr in self.__dataclass_fields__ - return getattr(self, attr) != _field_status.UNSET + assert key in self.__dataclass_fields__ + return getattr(self, key) != _field_status.UNSET - def is_fully_specified(self): + def is_fully_specified(self) -> bool: r"""Whether the :obj:`TensorAttr` has no unset fields.""" - return all([ - getattr(self, field) != _field_status.UNSET - for field in self.__dataclass_fields__ - ]) + return all([self.is_set(key) for key in self.__dataclass_fields__]) def fully_specify(self): - r"""Sets all UNSET fields to None.""" - for field in self.__dataclass_fields__: - if getattr(self, field) == _field_status.UNSET: - setattr(self, field, None) + r"""Sets all :obj:`UNSET` fields to :obj:`None`.""" + for key in self.__dataclass_fields__: + if not self.is_set(key): + setattr(self, key, None) return self def update(self, attr: 'TensorAttr'): - r"""Updates an :obj:`TensorAttr` with set attributes from another - :obj:`TensorAttr`.""" - for field in self.__dataclass_fields__: - val = getattr(attr, field) - if val != _field_status.UNSET: - setattr(self, field, val) + r"""Updates an :class:`TensorAttr` with set attributes from another + :class:`TensorAttr`.""" + for key in self.__dataclass_fields__: + if attr.is_set(key): + setattr(self, key, getattr(attr, key)) class AttrView(CastMixin): @@ -98,9 +93,9 @@ class AttrView(CastMixin): reference to the backing feature store as well as a :class:`TensorAttr` object that represents the view's state. - Users can create views either using the :obj:`AttrView` constructor, - :obj:`FeatureStore.view`, or by incompletely indexing a feature store. For - example, the following calls all create views: + Users can create views either using the :class:`AttrView` constructor, + :meth:`FeatureStore.view`, or by incompletely indexing a feature store. + For example, the following calls all create views: .. code-block:: python @@ -116,7 +111,6 @@ class AttrView(CastMixin): store[group_name]() store[group_name].feat[index] store[group_name, feat][index] - """ def __init__(self, store: 'FeatureStore', attr: TensorAttr): self.__dict__['_store'] = store @@ -124,11 +118,11 @@ def __init__(self, store: 'FeatureStore', attr: TensorAttr): # Advanced indexing ####################################################### - def __getattr__(self, key: Any) -> 'AttrView': - r"""Sets the first unset field of the backing :obj:`TensorAttr` object - to the attribute. This allows for :obj:`AttrView` to be indexed by + def __getattr__(self, key: Any) -> Union['AttrView', FeatureTensorType]: + r"""Sets the first unset field of the backing :class:`TensorAttr` object + to the attribute. This allows for :class:`AttrView` to be indexed by different values of attributes, in order. In particular, for a feature - store that we want to index by `group_name` group and `attr_name` attr, + store that we want to index by :obj:`group_name` and :obj:`attr_name`, the following code will do so: .. code-block:: python @@ -136,33 +130,34 @@ def __getattr__(self, key: Any) -> 'AttrView': store[group, attr] store[group].attr store.group.attr - """ out = copy.copy(self) - # First attribute that is UNSET - attr_name = None + # Find the first attribute name that is UNSET: + attr_name: Optional[str] = None for field in out._attr.__dataclass_fields__: if getattr(out._attr, field) == _field_status.UNSET: attr_name = field break if attr_name is None: - raise AttributeError(f"Cannot access attribute {key} on view " - f"{out} as all attributes have already been " - f"set in this view.") + raise AttributeError(f"Cannot access attribute '{key}' on view " + f"'{out}' as all attributes have already " + f"been set in this view") setattr(out._attr, attr_name, key) + if out._attr.is_fully_specified(): return out._store.get_tensor(out._attr) + return out def __getitem__(self, key: Any) -> Union['AttrView', FeatureTensorType]: - r"""Sets the first unset field of the backing :obj:`TensorAttr` object - to the attribute via indexing. This allows for :obj:`AttrView` to be + r"""Sets the first unset field of the backing :class:`TensorAttr` object + to the attribute via indexing. This allows for :class:`AttrView` to be indexed by different values of attributes, in order. In particular, for - a feature store that we want to index by `group_name` group and - `attr_name` attr, the following code will do so: + a feature store that we want to index by :obj:`group_name` and + :obj:`attr_name`, the following code will do so: .. code-block:: python @@ -174,72 +169,69 @@ def __getitem__(self, key: Any) -> Union['AttrView', FeatureTensorType]: # Setting attributes ###################################################### - def __setattr__(self, key, value): - r"""Supports attribute assignment to the backing :obj:`TensorAttr` of - an :obj:`AttrView`. This allows for :obj:`AttrView` objects to set + def __setattr__(self, key: str, value: Any): + r"""Supports attribute assignment to the backing :class:`TensorAttr` of + an :class:`AttrView`. This allows for :class:`AttrView` objects to set their backing attribute values. In particular, the following operation - sets the `index` of an :obj:`AttrView`: + sets the :obj:`index` of an :class:`AttrView`: .. code-block:: python - view = store.view(TensorAttr(group_name)) - view.index = torch.Tensor([1, 2, 3]) - + view = store.view(group_name) + view.index = torch.tensor([1, 2, 3]) """ if key not in self._attr.__dataclass_fields__: - raise ValueError(f"Attempted to set nonexistent attribute {key} " + raise ValueError(f"Attempted to set nonexistent attribute '{key}' " f"(acceptable attributes are " - f"{self._attr.__dataclass_fields__}).") + f"{self._attr.__dataclass_fields__})") setattr(self._attr, key, value) - def __setitem__(self, key, value): - r"""Supports attribute assignment to the backing :obj:`TensorAttr` of - an :obj:`AttrView` via indexing. This allows for :obj:`AttrView` + def __setitem__(self, key: str, value: Any): + r"""Supports attribute assignment to the backing :class:`TensorAttr` of + an :class:`AttrView` via indexing. This allows for :class:`AttrView` objects to set their backing attribute values. In particular, the - following operation sets the `index` of an :obj:`AttrView`: + following operation sets the `index` of an :class:`AttrView`: .. code-block:: python view = store.view(TensorAttr(group_name)) view['index'] = torch.Tensor([1, 2, 3]) - """ self.__setattr__(key, value) # Miscellaneous built-ins ################################################# def __call__(self) -> FeatureTensorType: - r"""Supports :obj:`AttrView` as a callable to force retrieval from + r"""Supports :class:`AttrView` as a callable to force retrieval from the currently specified attributes. In particular, this passes the - current :obj:`TensorAttr` object to a GET call, regardless of whether - all attributes have been specified. It returns the result of this - call. In particular, the following operation returns a Tensor by - performing a GET operation on the backing feature store: + current :class:`TensorAttr` object to a GET call, regardless of whether + all attributes have been specified. It returns the result of this call. + In particular, the following operation returns a tensor by performing a + GET operation on the backing feature store: .. code-block:: python store[group_name, attr_name]() - """ - # Set all UNSET values to None if forced execution + # Set all UNSET values to None: out = copy.copy(self) out._attr.fully_specify() return out._store.get_tensor(out._attr) - def __copy__(self): + def __copy__(self) -> 'AttrView': out = self.__class__.__new__(self.__class__) for key, value in self.__dict__.items(): out.__dict__[key] = value + out.__dict__['_attr'] = copy.copy(out.__dict__['_attr']) return out - def __eq__(self, __o: object) -> bool: - r"""Compares two :obj:`AttrView` objects by checking equality of their - :obj:`FeatureStore` references and :obj:`TensorAttr` attributes.""" - if not isinstance(__o, AttrView): + def __eq__(self, obj: Any) -> bool: + r"""Compares two :class:`AttrView` objects by checking equality of their + :class:`FeatureStore` references and :class:`TensorAttr` attributes.""" + if not isinstance(obj, AttrView): return False - - return self._store == __o._store and self._attr == __o._attr + return self._store == obj._store and self._attr == obj._attr def __repr__(self) -> str: return (f'{self.__class__.__name__}(store={self._store}, ' From f94920773d78ef3524a95918a6592da606895b14 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 29 Apr 2022 13:51:32 +0200 Subject: [PATCH 16/18] pass --- torch_geometric/data/feature_store.py | 164 +++++++++++++------------- 1 file changed, 80 insertions(+), 84 deletions(-) diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index 05dc9a78e1d2..7e3f8a82612f 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -241,9 +241,9 @@ def __repr__(self) -> str: class FeatureStore(MutableMapping): def __init__(self, attr_cls: Any = TensorAttr): r"""Initializes the feature store. Implementor classes can customize - the ordering and require nature of their :obj:`TensorAttr` tensor + the ordering and required nature of their :class:`TensorAttr` tensor attributes by subclassing :class:`TensorAttr` and passing the subclass - as `attr_cls`.""" + as :obj:`attr_cls`.""" super().__init__() self._attr_cls = attr_cls @@ -251,171 +251,167 @@ def __init__(self, attr_cls: Any = TensorAttr): @abstractmethod def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: - r"""Implemented by :obj:`FeatureStore` subclasses.""" + r"""To be implemented by :class:`FeatureStore` subclasses.""" pass def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool: - r"""Synchronously adds a :obj:`FeatureTensorType` object to the feature - store. + r"""Synchronously adds a :class:`FeatureTensorType` object to the + feature store. Args: - tensor (FeatureTensorType): the features to be added. - attr (TensorAttr): any relevant tensor attributes that correspond - to the feature tensor. See the :obj:`TensorAttr` documentation - for required and optional attributes. It is the job of - implementations of a FeatureStore to store this metadata in a - meaningful way that allows for tensor retrieval from a - :obj:`TensorAttr` object. + tensor (FeatureTensorType): The feature tensor to be added. + **attr (TensorAttr): Any relevant tensor attributes that correspond + to the feature tensor. See the :class:`TensorAttr` + documentation for required and optional attributes. It is the + job of implementations of a :class:`FeatureStore` to store this + metadata in a meaningful way that allows for tensor retrieval + from a :class:`TensorAttr` object. + Returns: - bool: whether insertion was successful. + bool: Whether insertion was successful. """ attr = self._attr_cls.cast(*args, **kwargs) if not attr.is_fully_specified(): - raise ValueError(f"The input TensorAttr {attr} is not fully " - f"specified. Please fully specify the input " - f"by specifying all UNSET fields.") - + raise ValueError(f"The input TensorAttr '{attr}' is not fully " + f"specified. Please fully specify the input by " + f"specifying all 'UNSET' fields") return self._put_tensor(tensor, attr) @abstractmethod def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: - r"""Implemented by :obj:`FeatureStore` subclasses.""" + r"""To be implemented by :class:`FeatureStore` subclasses.""" pass def get_tensor(self, *args, **kwargs) -> Optional[FeatureTensorType]: - r"""Synchronously obtains a :obj:`FeatureTensorType` object from the + r"""Synchronously obtains a :class:`FeatureTensorType` object from the feature store. Feature store implementors guarantee that the call - get_tensor(put_tensor(tensor, attr), attr) = tensor holds. + :obj:`get_tensor(put_tensor(tensor, attr), attr) = tensor` holds. Args: - attr (TensorAttr): any relevant tensor attributes that correspond - to the tensor to obtain. See :obj:`TensorAttr` documentation - for required and optional attributes. It is the job of - implementations of a FeatureStore to store this metadata in a - meaningful way that allows for tensor retrieval from a - :obj:`TensorAttr` object. + **attr (TensorAttr): Any relevant tensor attributes that correspond + to the feature tensor. See the :class:`TensorAttr` + documentation for required and optional attributes. It is the + job of implementations of a :class:`FeatureStore` to store this + metadata in a meaningful way that allows for tensor retrieval + from a :class:`TensorAttr` object. + Returns: - FeatureTensorType, optional: a tensor of the same type as the - index, or None if no tensor was found. + FeatureTensorType, optional: a Tensor of the same type as the + index, or :obj:`None` if no tensor was found. """ - def to_type(tensor): + def to_type(tensor: FeatureTensorType) -> FeatureTensorType: if tensor is None: return None - if isinstance(attr.index, torch.Tensor): - return torch.from_numpy(tensor) if isinstance( - tensor, np.ndarray) else tensor - if isinstance(attr.index, np.ndarray): - return tensor.numpy() if isinstance(tensor, - torch.Tensor) else tensor + if (isinstance(attr.index, torch.Tensor) + and isinstance(tensor, np.ndarray)): + return torch.from_numpy(tensor) + if (isinstance(attr.index, np.ndarray) + and isinstance(tensor, torch.Tensor)): + return tensor.numpy() return tensor attr = self._attr_cls.cast(*args, **kwargs) - if isinstance(attr.index, - slice) and (attr.index.start, attr.index.stop, - attr.index.step) == (None, None, None): - attr.index = None + if isinstance(attr.index, slice): + if attr.index.start == attr.index.stop == attr.index.step is None: + attr.index = None if not attr.is_fully_specified(): - raise ValueError(f"The input TensorAttr {attr} is not fully " - f"specified. Please fully specify the input " - f"by specifying all UNSET fields.") + raise ValueError(f"The input TensorAttr '{attr}' is not fully " + f"specified. Please fully specify the input by " + f"specifying all 'UNSET' fields.") return to_type(self._get_tensor(attr)) @abstractmethod def _remove_tensor(self, attr: TensorAttr) -> bool: - r"""Implemented by :obj:`FeatureStore` subclasses.""" + r"""To be implemented by :obj:`FeatureStore` subclasses.""" pass def remove_tensor(self, *args, **kwargs) -> bool: r"""Removes a :obj:`FeatureTensorType` object from the feature store. Args: - attr (TensorAttr): any relevant tensor attributes that correspond - to the tensor to remove. See :obj:`TensorAttr` documentation - for required and optional attributes. It is the job of - implementations of a FeatureStore to store this metadata in a - meaningful way that allows for tensor deletion from a - :obj:`TensorAttr` object. + **attr (TensorAttr): Any relevant tensor attributes that correspond + to the feature tensor. See the :class:`TensorAttr` + documentation for required and optional attributes. It is the + job of implementations of a :class:`FeatureStore` to store this + metadata in a meaningful way that allows for tensor retrieval + from a :class:`TensorAttr` object. Returns: - bool: whether deletion was succesful. + bool: Whether deletion was succesful. """ attr = self._attr_cls.cast(*args, **kwargs) if not attr.is_fully_specified(): - raise ValueError(f"The input TensorAttr {attr} is not fully " - f"specified. Please fully specify the input " - f"by specifying all UNSET fields.") - + raise ValueError(f"The input TensorAttr '{attr}' is not fully " + f"specified. Please fully specify the input by " + f"specifying all 'UNSET' fields.") self._remove_tensor(attr) def update_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool: - r"""Updates a :obj:`FeatureTensorType` object with a new value. + r"""Updates a :class:`FeatureTensorType` object with a new value. implementor classes can choose to define more efficient update methods; the default performs a removal and insertion. Args: - tensor (FeatureTensorType): the features to be added. - attr (TensorAttr): any relevant tensor attributes that correspond - to the old tensor. See :obj:`TensorAttr` documentation - for required and optional attributes. It is the job of - implementations of a FeatureStore to store this metadata in a - meaningful way that allows for tensor update from a - :obj:`TensorAttr` object. + tensor (FeatureTensorType): The feature tensor to be updated. + **attr (TensorAttr): Any relevant tensor attributes that correspond + to the feature tensor. See the :class:`TensorAttr` + documentation for required and optional attributes. It is the + job of implementations of a :class:`FeatureStore` to store this + metadata in a meaningful way that allows for tensor retrieval + from a :class:`TensorAttr` object. Returns: - bool: whether the update was succesful. + bool: Whether the update was succesful. """ attr = self._attr_cls.cast(*args, **kwargs) - if not attr.is_fully_specified(): - raise ValueError(f"The input TensorAttr {attr} is not fully " - f"specified. Please fully specify the input " - f"by specifying all UNSET fields.") - self.remove_tensor(attr) return self.put_tensor(tensor, attr) # :obj:`AttrView` methods ################################################# def view(self, *args, **kwargs) -> AttrView: - r"""Returns an :obj:`AttrView` of the feature store, with the defined + r"""Returns an :class:`AttrView` of the feature store, with the defined attributes set.""" - return AttrView(self, self._attr_cls.cast(*args, **kwargs)) + attr = self._attr_cls.cast(*args, **kwargs) + return AttrView(self, attr) # Python built-ins ######################################################## def __setitem__(self, key: TensorAttr, value: FeatureTensorType): r"""Supports store[tensor_attr] = tensor.""" + # CastMixin will handle the case of key being a tuple or TensorAttr + # object: key = self._attr_cls.cast(key) - # We need to fully specify the key for __setitem__ as it does not make - # sense to work with a view here. + # sense to work with a view here: key.fully_specify() self.put_tensor(value, key) - def __getitem__(self, key: TensorAttr): + def __getitem__(self, key: TensorAttr) -> Any: r"""Supports pythonic indexing into the feature store. In particular, the following rules are followed for indexing: - * Fully-specified indexes will produce a Tensor output. A - fully-specified index specifies all the required attributes in - :obj:`TensorAttr`. + * A fully-specified :obj:`key` will produce a tensor output. - * Partially-specified indexes will produce an AttrView output, which - is a view on the FeatureStore. If a view is called, it will produce - a Tensor output from the corresponding (partially specified) - attributes. + * A partially-specified :obj:`key` will produce an :class:`AttrView` + output, which is a view on the :class:`FeatureStore`. If a view is + called, it will produce a tensor output from the corresponding + (partially specified) attributes. """ # CastMixin will handle the case of key being a tuple or TensorAttr - # object. + # object: attr = self._attr_cls.cast(key) if attr.is_fully_specified(): return self.get_tensor(attr) - return AttrView(self, attr) + return self.view(attr) def __delitem__(self, key: TensorAttr): r"""Supports del store[tensor_attr].""" + # CastMixin will handle the case of key being a tuple or TensorAttr + # object: key = self._attr_cls.cast(key) key.fully_specify() self.remove_tensor(key) @@ -423,8 +419,8 @@ def __delitem__(self, key: TensorAttr): def __iter__(self): raise NotImplementedError - def __eq__(self, __o: object) -> bool: - return id(self) == id(__o) + def __eq__(self, obj: object) -> bool: + return id(self) == id(obj) @abstractmethod def __len__(self): From b2064d54022fdb850742b8dcc640369c2f5c6ab8 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 29 Apr 2022 13:53:53 +0200 Subject: [PATCH 17/18] pass --- torch_geometric/typing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index 4fdec3eccd87..60bce93dfd59 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional, Tuple, Union -import numpy +import numpy as np import torch from torch import Tensor from torch_sparse import SparseTensor @@ -23,7 +23,7 @@ Metadata = Tuple[List[NodeType], List[EdgeType]] # A representation of a feature tensor -FeatureTensorType = Union[torch.TensorType, numpy.ndarray] +FeatureTensorType = Union[torch.Tensor, np.ndarray] # Types for message passing ################################################### From 23d4e582549b508703b6fbdebe5296521fa889f9 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 29 Apr 2022 14:12:42 +0200 Subject: [PATCH 18/18] pass --- test/data/test_feature_store.py | 61 +++++++++++++++------------------ 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/test/data/test_feature_store.py b/test/data/test_feature_store.py index 4fc94ee94300..c7c77b6c985a 100644 --- a/test/data/test_feature_store.py +++ b/test/data/test_feature_store.py @@ -18,40 +18,38 @@ def __init__(self): super().__init__() self.store = {} - @classmethod - def key(cls, attr: TensorAttr): - return (attr.group_name or '', attr.attr_name or '') + @staticmethod + def key(attr: TensorAttr) -> str: + return (attr.group_name, attr.attr_name) def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: index = attr.index - # Not set or None indices define the obvious index + # Not set or None indices define the obvious index: if index is None: index = torch.arange(0, tensor.shape[0]) - # Store the index as a column - self.store[MyFeatureStore.key(attr)] = torch.cat( - (index.reshape(-1, 1), tensor), dim=1) + # Store the index as a column: + self.store[MyFeatureStore.key(attr)] = (index, tensor) return True def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: - tensor = self.store.get(MyFeatureStore.key(attr), None) + index, tensor = self.store.get(MyFeatureStore.key(attr), (None, None)) + if tensor is None: return None - # Not set or None indices return the whole tensor + # Not set or None indices return the whole tensor: if attr.index is None: - return tensor[:, 1:] - - # Index into the tensor - indices = torch.cat([(tensor[:, 0] == v).nonzero() - for v in attr.index]).reshape(1, -1)[0] + return tensor - return torch.index_select(tensor[:, 1:], 0, indices) + idx = torch.cat([(index == v).nonzero() for v in attr.index]).view(-1) + return tensor[idx] def _remove_tensor(self, attr: TensorAttr) -> bool: del self.store[MyFeatureStore.key(attr)] + return True def __len__(self): raise NotImplementedError @@ -70,9 +68,9 @@ def __init__(self): super().__init__() self._attr_cls = MyTensorAttrNoGroupName - @classmethod - def key(cls, attr: TensorAttr): - return attr.attr_name or '' + @staticmethod + def key(attr: TensorAttr) -> str: + return attr.attr_name def __len__(self): raise NotImplementedError @@ -82,33 +80,30 @@ def test_feature_store(): r"""Tests basic API and indexing functionality of a feature store.""" store = MyFeatureStore() tensor = torch.Tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) - index = torch.Tensor([0, 1, 2]) - attr_name = 'feat' group_name = 'A' - + attr_name = 'feat' + index = torch.tensor([0, 1, 2]) attr = TensorAttr(group_name, attr_name, index) - # Normal API + # Normal API: store.put_tensor(tensor, attr) assert torch.equal(store.get_tensor(attr), tensor) assert torch.equal( - store.get_tensor(group_name, attr_name, torch.Tensor([0, 2])), - tensor[[0, 2]], + store.get_tensor(group_name, attr_name, index=torch.tensor([0, 2])), + tensor[torch.tensor([0, 2])], ) assert store.get_tensor(None, None, index) is None store.remove_tensor(group_name, attr_name, None) assert store.get_tensor(attr) is None - # Views - view = store.view(TensorAttr(group_name=group_name)) + # Views: + view = store.view(group_name=group_name) view.attr_name = attr_name view['index'] = index assert view == AttrView(store, TensorAttr(group_name, attr_name, index)) - # Indexing - - # Setting via indexing + # Indexing: store[group_name, attr_name, index] = tensor # Fully-specified forms, all of which produce a tensor output @@ -134,7 +129,7 @@ def test_feature_store(): assert isinstance(store[group_name], AttrView) assert torch.equal(store[group_name](), tensor) - # Deletion + # Deletion: del store[group_name, attr_name, index] assert store[group_name, attr_name, index] is None del store[group_name] @@ -144,14 +139,14 @@ def test_feature_store(): def test_feature_store_override(): store = MyFeatureStoreNoGroupName() tensor = torch.Tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) - index = torch.Tensor([0, 1, 2]) attr_name = 'feat' + index = torch.tensor([0, 1, 2]) - # Only use attr_name and index, in that order + # Only use attr_name and index, in that order: store[attr_name, index] = tensor - # A few assertions to ensure group_name is not needed + # A few assertions to ensure group_name is not needed: assert isinstance(store[attr_name], AttrView) assert torch.equal(store[attr_name, index], tensor) assert torch.equal(store[attr_name][index], tensor)