Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FeatureStore abstraction definition #4534

Merged
merged 23 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions test/data/test_feature_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Optional

import torch

from torch_geometric.data.feature_store import FeatureStore, TensorAttr
from torch_geometric.typing import FeatureTensorType


class MyFeatureStore(FeatureStore):
    r"""A basic feature store, does NOT implement all functionality of a
    fully-fledged feature store. Only works for Torch tensors.

    Tensors are kept in a plain dictionary keyed by
    :obj:`(tensor_type, node_type, graph_type)`. Each stored value holds the
    index as its first column, followed by the feature columns.
    """
    def __init__(self):
        super().__init__(backend='test')
        self.store = {}

    @classmethod
    def key(cls, attr: TensorAttr):
        r"""Returns the dictionary key for :obj:`attr`; unset fields map to
        the empty string."""
        return (attr.tensor_type or '', attr.node_type or '',
                attr.graph_type or '')

    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
        r"""Stores :obj:`tensor` with :obj:`attr.index` prepended as an extra
        leading column, overwriting any previous entry for the same key."""
        self.store[self.key(attr)] = torch.cat(
            (attr.index.reshape(-1, 1), tensor), dim=1)
        return True

    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
        r"""Returns the stored feature rows for :obj:`attr`, or :obj:`None`
        if nothing is stored under its key."""
        tensor = self.store.get(self.key(attr), None)
        if tensor is None:
            return None
        if attr.index is not None:
            # Column 0 holds the stored index: look up the row position of
            # every requested index value, in request order.
            indices = torch.cat([(tensor[:, 0] == v).nonzero()
                                 for v in attr.index]).reshape(1, -1)[0]
            return torch.index_select(tensor[:, 1:], 0, indices)
        return tensor[:, 1:]

    def _remove_tensor(self, attr: TensorAttr) -> bool:
        r"""Deletes the whole entry stored under the key of :obj:`attr`.

        Raises:
            NotImplementedError: if :obj:`attr.index` is set (partial
                removal is not supported).
            KeyError: if nothing is stored under the key.
        """
        if attr.index is not None:
            raise NotImplementedError
        del self.store[self.key(attr)]
        # Report success explicitly, matching the declared `bool` return.
        return True

    def __len__(self):
        raise NotImplementedError


def test_feature_store():
    r"""Tests basic API and indexing functionality of a feature store."""

    store = MyFeatureStore()
    tensor = torch.Tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
    index = torch.Tensor([0, 1, 2])

    tensor_type = 'feat'
    node_type = 'A'
    graph_type = 'graph'

    attr = TensorAttr(index, tensor_type, node_type, graph_type)

    # Normal API
    store.put_tensor(tensor, attr)
    assert torch.equal(store.get_tensor(attr), tensor)
    assert torch.equal(
        store.get_tensor(
            (torch.Tensor([0, 2]), tensor_type, node_type, graph_type)),
        tensor[[0, 2]],
    )
    # A bare index is cast to an attribute with no type information, which
    # matches nothing that was stored.
    assert store.get_tensor(index) is None
    store.remove_tensor((None, tensor_type, node_type, graph_type))
    assert store.get_tensor(attr) is None

    # Indexing
    store[attr] = tensor
    assert torch.equal(store[attr], tensor)
    assert torch.equal(
        store[(torch.Tensor([0, 2]), tensor_type, node_type, graph_type)],
        tensor[[0, 2]],
    )
    assert store[index] is None
    del store[(None, tensor_type, node_type, graph_type)]
    assert store.get_tensor(attr) is None

    # Advanced indexing
    store[attr] = tensor
    view = store[TensorAttr(node_type=node_type, graph_type=graph_type)]
    assert torch.equal(view.feat[index], tensor)
220 changes: 220 additions & 0 deletions torch_geometric/data/feature_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
r"""
This module defines the abstraction for a Graph feature store. The goal of a
feature store is to abstract away all node and edge feature memory management
so that varying implementations can allow for independent scale-out.

This particular feature store abstraction makes a few key assumptions:
* The features we care about storing are all associated with some sort of
`index`; explicitly for PyG, the index of the node in the graph (or
the heterogeneous component of the graph it resides in).
* A feature can uniquely be identified from (a) its index and (b) any other
associated attributes specified in :obj:`TensorAttr`.

It is the job of a feature store implementor class to handle these assumptions
properly. For example, a simple in-memory feature store implementation may
concatenate all metadata values with a feature index and use this as a unique
index in a KV store. More complicated implementations may choose to partition
features in interesting manners based on the provided metadata.

Major TODOs for future implementation:
* Async `put` and `get` functionality
"""
from abc import abstractmethod
from collections.abc import MutableMapping
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np
import torch

from torch_geometric.typing import FeatureTensorType
from torch_geometric.utils.mixin import CastMixin


@dataclass
class TensorAttr(CastMixin):
    r"""Defines the attributes of a :obj:`FeatureStore` tensor.

    Every field is optional and defaults to :obj:`None`.
    """

    # The node indices the rows of the tensor correspond to
    index: Optional[FeatureTensorType] = None

    # The type of the feature tensor (may be used if there are multiple
    # different feature tensors for the same node index)
    tensor_type: Optional[str] = None

    # The type of the nodes that the tensor corresponds to (may be used for
    # heterogeneous graphs)
    node_type: Optional[str] = None

    # The type of the graph that the nodes correspond to (may be used if a
    # feature store supports multiple graphs)
    graph_type: Optional[str] = None

class AttrView:
    r"""A view of a :obj:`FeatureStore` that is obtained from an incomplete
    specification of attributes. This view stores a reference to the
    originating feature store as well as a :obj:`TensorAttr` object that
    represents the view's (incomplete) state.

    As a result, store[TensorAttr(...)].tensor_type[idx] allows for indexing
    into the store.

    Refining a view never modifies the attribute object it was created from:
    attribute access and indexing both work on a shallow copy.
    """
    _store: 'FeatureStore'
    attr: 'TensorAttr'

    def __init__(self, store, attr):
        self._store = store
        self.attr = attr

    def _updated_attr(self, **changes):
        r"""Returns a shallow copy of :obj:`self.attr` with :obj:`changes`
        applied, leaving the original attribute object untouched."""
        import copy  # local import: not part of this module's import block

        attr = copy.copy(self.attr)
        for name, value in changes.items():
            setattr(attr, name, value)
        return attr

    def __getattr__(self, tensor_type):
        r"""Supports attr_view.attr

        Returns a new view whose :obj:`tensor_type` is set to the accessed
        attribute name.

        Raises:
            AttributeError: for dunder names and for the view's own fields.
        """
        # `__getattr__` only runs when normal lookup fails. Dunder probes
        # (e.g. from `copy` or `pickle`) must not be mistaken for tensor
        # types, and a missing `_store`/`attr` on a not-yet-initialised
        # instance must not recurse back into this method.
        if tensor_type in ('_store', 'attr') or (
                tensor_type.startswith('__') and tensor_type.endswith('__')):
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute "
                f"'{tensor_type}'")
        return self.__class__(self._store,
                              self._updated_attr(tensor_type=tensor_type))

    def __getitem__(self, index: 'FeatureTensorType'):
        r"""Supports attr_view.attr[idx]

        Returns the result of :obj:`get_tensor` on the originating store for
        the view's attributes combined with :obj:`index`.
        """
        return self._store.get_tensor(self._updated_attr(index=index))

    def __repr__(self) -> str:
        return f'AttrView(store={self._store}, attr={self.attr})'


class FeatureStore(MutableMapping):
    r"""Abstract base class of a feature store.

    Subclasses implement :meth:`_put_tensor`, :meth:`_get_tensor`,
    :meth:`_remove_tensor` and :meth:`__len__`. The public methods cast their
    attribute argument via :obj:`TensorAttr.cast`, validate it and delegate
    to those hooks.
    """
    def __init__(self, backend: Any):
        r"""Initializes the feature store with a specified backend."""
        self.backend = backend

    # Core (CRUD) #############################################################

    @abstractmethod
    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
        r"""Implemented by :obj:`FeatureStore` subclasses."""
        pass

    def put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
        r"""Synchronously adds a :obj:`FeatureTensorType` object to the feature
        store.

        Args:
            tensor (FeatureTensorType): the features to be added.
            attr (TensorAttr): any relevant tensor attributes that correspond
                to the feature tensor. See the :obj:`TensorAttr` documentation
                for required and optional attributes. It is the job of
                implementations of a FeatureStore to store this metadata in a
                meaningful way that allows for tensor retrieval from a
                :obj:`TensorAttr` object.
        Returns:
            bool: whether insertion was successful.
        """
        attr = TensorAttr.cast(attr)
        assert attr.index is not None
        # The index holds one entry per *row* of the tensor, so compare the
        # leading dimensions. `.shape` exists on both torch tensors and numpy
        # arrays, whereas `.size(dim=...)` is torch-only.
        assert attr.index.shape[0] == tensor.shape[0]
        return self._put_tensor(tensor, attr)

    @abstractmethod
    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
        r"""Implemented by :obj:`FeatureStore` subclasses."""
        pass

    def get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
        r"""Synchronously obtains a :obj:`FeatureTensorType` object from the
        feature store. Feature store implementors guarantee that after
        put_tensor(tensor, attr), the call get_tensor(attr) returns tensor.

        Args:
            attr (TensorAttr): any relevant tensor attributes that correspond
                to the tensor to obtain. See :obj:`TensorAttr` documentation
                for required and optional attributes. It is the job of
                implementations of a FeatureStore to store this metadata in a
                meaningful way that allows for tensor retrieval from a
                :obj:`TensorAttr` object.
        Returns:
            FeatureTensorType, optional: a tensor of the same type as the
                index, or None if no tensor was found.
        """
        attr = TensorAttr.cast(attr)
        assert attr.index is not None

        out = self._get_tensor(attr)
        # Hand back a torch tensor when the caller indexed with one but the
        # implementation produced a numpy array.
        if isinstance(attr.index, torch.Tensor) and isinstance(
                out, np.ndarray):
            return torch.from_numpy(out)
        return out

    @abstractmethod
    def _remove_tensor(self, attr: TensorAttr) -> bool:
        r"""Implemented by :obj:`FeatureStore` subclasses."""
        pass

    def remove_tensor(self, attr: TensorAttr) -> bool:
        r"""Removes a :obj:`FeatureTensorType` object from the feature store.

        Args:
            attr (TensorAttr): any relevant tensor attributes that correspond
                to the tensor to remove. See :obj:`TensorAttr` documentation
                for required and optional attributes. It is the job of
                implementations of a FeatureStore to store this metadata in a
                meaningful way that allows for tensor deletion from a
                :obj:`TensorAttr` object.

        Returns:
            bool: whether deletion was successful, as reported by the
                subclass implementation.
        """
        attr = TensorAttr.cast(attr)
        return self._remove_tensor(attr)

    def update_tensor(self, tensor: FeatureTensorType,
                      attr: TensorAttr) -> bool:
        r"""Updates a :obj:`FeatureTensorType` object with a new value.
        Implementor classes can choose to define more efficient update
        methods; the default performs a removal and insertion.

        Args:
            tensor (FeatureTensorType): the features to be added.
            attr (TensorAttr): any relevant tensor attributes that correspond
                to the old tensor. See :obj:`TensorAttr` documentation
                for required and optional attributes. It is the job of
                implementations of a FeatureStore to store this metadata in a
                meaningful way that allows for tensor update from a
                :obj:`TensorAttr` object.

        Returns:
            bool: whether the update was successful.
        """
        attr = TensorAttr.cast(attr)
        self.remove_tensor(attr)
        return self.put_tensor(tensor, attr)

    # Python built-ins ########################################################

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(backend={self.backend})'

    def __setitem__(self, key: TensorAttr, value: FeatureTensorType):
        r"""Supports store[tensor_attr] = tensor."""
        key = TensorAttr.cast(key)
        assert key.index is not None
        self.put_tensor(value, key)

    def __getitem__(self, key: TensorAttr):
        r"""Supports store[tensor_attr]. If tensor_attr has index specified,
        will obtain the corresponding features from the store. Otherwise, will
        return an :obj:`AttrView` which can be indexed independently."""
        # NOTE(review): `__getattr__`/`__setattr__` equivalents on the store
        # itself were discussed in review and deliberately left out for now;
        # attribute-style access is only offered on the returned view.
        key = TensorAttr.cast(key)
        if key.index is not None:
            return self.get_tensor(key)
        return AttrView(self, key)

    def __delitem__(self, key: TensorAttr):
        r"""Supports del store[tensor_attr]."""
        key = TensorAttr.cast(key)
        self.remove_tensor(key)

    def __iter__(self):
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        pass
5 changes: 5 additions & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Dict, List, Optional, Tuple, Union

import numpy
mananshah99 marked this conversation as resolved.
Show resolved Hide resolved
import torch
from torch import Tensor
from torch_sparse import SparseTensor

Expand All @@ -20,6 +22,9 @@

Metadata = Tuple[List[NodeType], List[EdgeType]]

# A representation of a feature tensor: either a PyTorch tensor or a NumPy
# array. (`torch.TensorType` is the TorchScript type descriptor, not the
# tensor class, so `Tensor` is what belongs in this union.)
FeatureTensorType = Union[Tensor, numpy.ndarray]
rusty1s marked this conversation as resolved.
Show resolved Hide resolved

# Types for message passing ###################################################

Adj = Union[Tensor, SparseTensor]
Expand Down
14 changes: 14 additions & 0 deletions torch_geometric/utils/mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
class CastMixin:
    r"""Mixin that adds a :meth:`cast` constructor which builds an instance
    of the class from loosely-typed input."""
    @classmethod
    def cast(cls, *args, **kwargs):  # TODO Can we apply this recursively?
        r"""Converts the given arguments into an instance of :obj:`cls`.

        When called with exactly one positional argument and no keyword
        arguments, that argument is interpreted as follows:

        * :obj:`None` is returned as :obj:`None`.
        * Any :class:`CastMixin` instance is returned unchanged (the same
          object, not a copy).
        * A tuple or list is unpacked as positional arguments.
        * A dictionary is unpacked as keyword arguments.

        In every other case, all arguments are forwarded to the constructor
        of :obj:`cls` as-is.
        """
        if len(args) == 1 and len(kwargs) == 0:
            elem = args[0]
            if elem is None:
                return None
            if isinstance(elem, CastMixin):
                return elem
            if isinstance(elem, (tuple, list)):
                return cls(*elem)
            if isinstance(elem, dict):
                return cls(**elem)
        return cls(*args, **kwargs)