Skip to content

Commit

Permalink
Let NeighborLoader accept Tuple[FeatureStore, GraphStore] (#4817)
Browse files Browse the repository at this point in the history
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
mananshah99 and rusty1s authored Jun 23, 2022
1 parent 85cddb3 commit 7ea71e3
Show file tree
Hide file tree
Showing 12 changed files with 553 additions and 90 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
Expand Down
15 changes: 14 additions & 1 deletion test/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,14 @@ def test_basic_feature_store():
out = data.get_tensor(attr_name='x', index=None)
assert torch.equal(x, out)

# Get tensor size:
assert data.get_tensor_size(attr_name='x') == (20, 20)

# Get tensor attrs:
tensor_attrs = data.get_all_tensor_attrs()
assert len(tensor_attrs) == 1
assert tensor_attrs[0].attr_name == 'x'

# Remove tensor:
assert 'x' in data.__dict__['_store']
data.remove_tensor(attr_name='x', index=None)
Expand All @@ -271,6 +279,7 @@ def test_basic_feature_store():


def test_basic_graph_store():
r"""Test the core graph store API."""
data = Data()

edge_index = torch.LongTensor([[0, 1], [1, 2]])
Expand All @@ -285,7 +294,7 @@ def assert_equal_tensor_tuple(expected, actual):
# to confirm that `GraphStore` works as intended.
coo = adj.coo()[:-1]
csr = adj.csr()[:-1]
csc = adj.csc()[:-1]
csc = adj.csc()[-2::-1] # (row, colptr)

# Put:
data.put_edge_index(coo, layout='coo')
Expand All @@ -296,3 +305,7 @@ def assert_equal_tensor_tuple(expected, actual):
assert_equal_tensor_tuple(coo, data.get_edge_index('coo'))
assert_equal_tensor_tuple(csr, data.get_edge_index('csr'))
assert_equal_tensor_tuple(csc, data.get_edge_index('csc'))

# Get attrs:
edge_attrs = data.get_all_edge_attrs()
assert len(edge_attrs) == 3
30 changes: 15 additions & 15 deletions test/data/test_feature_store.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import dataclass
from typing import Optional
from typing import Dict, List, Optional, Tuple

import pytest
import torch
from torch import Tensor

from torch_geometric.data.feature_store import (
AttrView,
Expand All @@ -16,7 +17,7 @@
class MyFeatureStore(FeatureStore):
def __init__(self):
super().__init__()
self.store = {}
self.store: Dict[Tuple[str, str], Tensor] = {}

@staticmethod
def key(attr: TensorAttr) -> str:
Expand All @@ -36,7 +37,6 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:

def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
index, tensor = self.store.get(MyFeatureStore.key(attr), (None, None))

if tensor is None:
return None

Expand All @@ -51,6 +51,12 @@ def _remove_tensor(self, attr: TensorAttr) -> bool:
del self.store[MyFeatureStore.key(attr)]
return True

def _get_tensor_size(self, attr: TensorAttr) -> Tuple:
return self._get_tensor(attr).size()

def get_all_tensor_attrs(self) -> List[str]:
return [TensorAttr(*key) for key in self.store.keys()]

def __len__(self):
raise NotImplementedError

Expand All @@ -68,16 +74,8 @@ def __init__(self):
super().__init__()
self._tensor_attr_cls = MyTensorAttrNoGroupName

@staticmethod
def key(attr: TensorAttr) -> str:
return attr.attr_name

def __len__(self):
raise NotImplementedError


def test_feature_store():
r"""Tests basic API and indexing functionality of a feature store."""
store = MyFeatureStore()
tensor = torch.Tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]])

Expand All @@ -93,9 +91,9 @@ def test_feature_store():
store.get_tensor(group_name, attr_name, index=torch.tensor([0, 2])),
tensor[torch.tensor([0, 2])],
)
assert store.get_tensor(None, None, index) is None
store.remove_tensor(group_name, attr_name, None)
assert store.get_tensor(attr) is None
with pytest.raises(KeyError):
_ = store.get_tensor(attr)

# Views:
view = store.view(group_name=group_name)
Expand Down Expand Up @@ -131,9 +129,11 @@ def test_feature_store():

# Deletion:
del store[group_name, attr_name, index]
assert store[group_name, attr_name, index] is None
with pytest.raises(KeyError):
_ = store[group_name, attr_name, index]
del store[group_name]
assert store[group_name]() is None
with pytest.raises(KeyError):
_ = store[group_name]()


def test_feature_store_override():
Expand Down
20 changes: 16 additions & 4 deletions test/data/test_graph_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional
from typing import Dict, Optional, Tuple

import pytest
import torch
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.data.graph_store import (
Expand All @@ -14,11 +16,11 @@
class MyGraphStore(GraphStore):
def __init__(self):
super().__init__()
self.store = {}
self.store: Dict[EdgeAttr, Tuple[Tensor, Tensor]] = {}

@staticmethod
def key(attr: EdgeAttr) -> str:
return f"{attr.edge_type or '<default>'}_{attr.layout}"
return (attr.edge_type, attr.layout.value)

def _put_edge_index(self, edge_index: EdgeTensorType,
edge_attr: EdgeAttr) -> bool:
Expand All @@ -27,6 +29,9 @@ def _put_edge_index(self, edge_index: EdgeTensorType,
def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
return self.store.get(MyGraphStore.key(edge_attr), None)

def get_all_edge_attrs(self):
return [EdgeAttr(*key) for key in self.store.keys()]


def test_graph_store():
graph_store = MyGraphStore()
Expand All @@ -42,7 +47,7 @@ def assert_equal_tensor_tuple(expected, actual):
# to confirm that `GraphStore` works as intended.
coo = adj.coo()[:-1]
csr = adj.csr()[:-1]
csc = adj.csc()[:-1]
csc = adj.csc()[-2::-1] # (row, colptr)

# Put:
graph_store['edge', EdgeLayout.COO] = coo
Expand All @@ -53,3 +58,10 @@ def assert_equal_tensor_tuple(expected, actual):
assert_equal_tensor_tuple(coo, graph_store['edge', 'coo'])
assert_equal_tensor_tuple(csr, graph_store['edge', 'csr'])
assert_equal_tensor_tuple(csc, graph_store['edge', 'csc'])

# Get attrs:
edge_attrs = graph_store.get_all_edge_attrs()
assert len(edge_attrs) == 3

with pytest.raises(KeyError):
_ = graph_store['edge_2', 'coo']
36 changes: 25 additions & 11 deletions test/data/test_hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,15 @@ def test_basic_feature_store():
out = data.get_tensor(group_name='paper', attr_name='x', index=None)
assert torch.equal(x, out)

# Get tensor size:
assert data.get_tensor_size(group_name='paper', attr_name='x') == (20, 20)

# Get tensor attrs:
tensor_attrs = data.get_all_tensor_attrs()
assert len(tensor_attrs) == 1
assert tensor_attrs[0].group_name == 'paper'
assert tensor_attrs[0].attr_name == 'x'

# Remove tensor:
assert 'x' in data['paper'].__dict__['_mapping']
data.remove_tensor(group_name='paper', attr_name='x', index=None)
Expand All @@ -437,7 +446,8 @@ def test_basic_graph_store():
data = HeteroData()

edge_index = torch.LongTensor([[0, 1], [1, 2]])
adj = torch_sparse.SparseTensor(row=edge_index[0], col=edge_index[1])
adj = torch_sparse.SparseTensor(row=edge_index[0], col=edge_index[1],
sparse_sizes=(3, 3))

def assert_equal_tensor_tuple(expected, actual):
assert len(expected) == len(actual)
Expand All @@ -448,17 +458,21 @@ def assert_equal_tensor_tuple(expected, actual):
# to confirm that `GraphStore` works as intended.
coo = adj.coo()[:-1]
csr = adj.csr()[:-1]
csc = adj.csc()[:-1]
csc = adj.csc()[-2::-1] # (row, colptr)

# Put:
data.put_edge_index(coo, layout='coo', edge_type='1')
data.put_edge_index(csr, layout='csr', edge_type='2')
data.put_edge_index(csc, layout='csc', edge_type='3')
data.put_edge_index(coo, layout='coo', edge_type=('a', 'to', 'b'))
data.put_edge_index(csr, layout='csr', edge_type=('a', 'to', 'c'))
data.put_edge_index(csc, layout='csc', edge_type=('b', 'to', 'c'))

# Get:
assert_equal_tensor_tuple(coo,
data.get_edge_index(layout='coo', edge_type='1'))
assert_equal_tensor_tuple(csr,
data.get_edge_index(layout='csr', edge_type='2'))
assert_equal_tensor_tuple(csc,
data.get_edge_index(layout='csc', edge_type='3'))
assert_equal_tensor_tuple(
coo, data.get_edge_index(layout='coo', edge_type=('a', 'to', 'b')))
assert_equal_tensor_tuple(
csr, data.get_edge_index(layout='csr', edge_type=('a', 'to', 'c')))
assert_equal_tensor_tuple(
csc, data.get_edge_index(layout='csc', edge_type=('b', 'to', 'c')))

# Get attrs:
edge_attrs = data.get_all_edge_attrs()
assert len(edge_attrs) == 3
Loading

0 comments on commit 7ea71e3

Please sign in to comment.