Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IndexToMask and MaskToIndex transforms #5375

Merged
merged 11 commits into from
Sep 15, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added `IndexToMask` and `MaskToIndex` transforms ([#5375](https://github.com/pyg-team/pytorch_geometric/pull/5375))
- Added `FeaturePropagation` transform ([#5387](https://github.com/pyg-team/pytorch_geometric/pull/5387))
- Added `PositionalEncoding` ([#5381](https://github.com/pyg-team/pytorch_geometric/pull/5381))
- Consolidated sampler routines behind `torch_geometric.sampler`, enabling ease of extensibility in the future ([#5312](https://github.com/pyg-team/pytorch_geometric/pull/5312), [#5365](https://github.com/pyg-team/pytorch_geometric/pull/5365), [#5402](https://github.com/pyg-team/pytorch_geometric/pull/5402), [#5404](https://github.com/pyg-team/pytorch_geometric/pull/5404)), [#5418](https://github.com/pyg-team/pytorch_geometric/pull/5418))
Expand Down
91 changes: 91 additions & 0 deletions test/transforms/test_mask_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import copy

import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import IndexToMask, MaskToIndex


def test_index_to_mask():
assert str(IndexToMask()) == ('IndexToMask(attrs=None, sizes=None, '
'replace=False)')

train_index = torch.arange(0, 3)
test_index = torch.arange(3, 5)
data = Data(train_index=train_index, test_index=test_index, num_nodes=5)

out = IndexToMask(replace=True)(copy.copy(data))
assert len(out) == len(data)
assert out.train_mask.tolist() == [True, True, True, False, False]
assert out.test_mask.tolist() == [False, False, False, True, True]

out = IndexToMask(replace=False)(copy.copy(data))
assert len(out) == len(data) + 2

out = IndexToMask(sizes=6, replace=True)(copy.copy(data))
assert out.train_mask.tolist() == [True, True, True, False, False, False]
assert out.test_mask.tolist() == [False, False, False, True, True, False]

out = IndexToMask(attrs='train_index')(copy.copy(data))
assert len(out) == len(data) + 1
assert 'train_index' in out
assert 'train_mask' in out
assert 'test_index' in out
assert 'test_mask' not in out


def test_mask_to_index():
assert str(MaskToIndex()) == 'MaskToIndex(attrs=None, replace=False)'

train_mask = torch.tensor([True, True, True, False, False])
test_mask = torch.tensor([False, False, False, True, True])
data = Data(train_mask=train_mask, test_mask=test_mask)

out = MaskToIndex(replace=True)(copy.copy(data))
assert len(out) == len(data)
assert out.train_index.tolist() == [0, 1, 2]
assert out.test_index.tolist() == [3, 4]

out = MaskToIndex(replace=False)(copy.copy(data))
assert len(out) == len(data) + 2

out = MaskToIndex(attrs='train_mask')(copy.copy(data))
assert len(out) == len(data) + 1
assert 'train_mask' in out
assert 'train_index' in out
assert 'test_mask' in out
assert 'test_index' not in out


def test_hetero_index_to_mask():
data = HeteroData()
data['u'].train_index = torch.arange(0, 3)
data['u'].test_index = torch.arange(3, 5)
data['u'].num_nodes = 5

data['v'].train_index = torch.arange(0, 3)
data['v'].test_index = torch.arange(3, 5)
data['v'].num_nodes = 5

out = IndexToMask()(copy.copy(data))
assert len(out) == len(data) + 2
assert 'train_mask' in out['u']
assert 'test_mask' in out['u']
assert 'train_mask' in out['v']
assert 'test_mask' in out['v']


def test_hetero_mask_to_index():
data = HeteroData()
data['u'].train_mask = torch.tensor([True, True, True, False, False])
data['u'].test_mask = torch.tensor([False, False, False, True, True])

data['v'].train_mask = torch.tensor([True, True, True, False, False])
data['v'].test_mask = torch.tensor([False, False, False, True, True])

out = MaskToIndex()(copy.copy(data))
assert len(out) == len(data) + 2
assert 'train_index' in out['u']
assert 'test_index' in out['u']
assert 'train_index' in out['v']
assert 'test_index' in out['v']
3 changes: 3 additions & 0 deletions torch_geometric/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from .virtual_node import VirtualNode
from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE
from .feature_propagation import FeaturePropagation
from .mask import IndexToMask, MaskToIndex

__all__ = [
'BaseTransform',
Expand Down Expand Up @@ -108,6 +109,8 @@
'AddLaplacianEigenvectorPE',
'AddRandomWalkPE',
'FeaturePropagation',
'IndexToMask',
'MaskToIndex',
]

classes = __all__
Expand Down
126 changes: 126 additions & 0 deletions torch_geometric/transforms/mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from typing import List, Optional, Union

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.data.storage import BaseStorage
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import index_to_mask, mask_to_index

AnyData = Union[Data, HeteroData]


def get_attrs_with_suffix(
attrs: Optional[List[str]],
store: BaseStorage,
suffix: str,
) -> List[str]:
if attrs is not None:
return attrs
return [key for key in store.keys() if key.endswith(suffix)]


def get_mask_size(attr: str, store: BaseStorage, size: Optional[int]) -> int:
if size is not None:
return size
return store.num_edges if store.is_edge_attr(attr) else store.num_nodes


@functional_transform('index_to_mask')
class IndexToMask(BaseTransform):
r"""Converts indices to a mask representation
(functional name: :obj:`index_to_mask`).

Args:
attrs (str, [str], optional): If given, will only perform index to mask
conversion for the given attributes. If omitted, will infer the
attributes from the suffix :obj:`_index`. (default: :obj:`None`)
sizes (int, [int], optional). The size of the mask. If set to
:obj:`None`, an automatically sized tensor is returned. The number
of nodes will be used by default, except for edge attributes which
will use the number of edges as the mask size.
(default: :obj:`None`)
replace (bool, optional): if set to :obj:`True` replaces the index
attributes with mask tensors. (default: :obj:`False`)
"""
def __init__(
self,
attrs: Optional[Union[str, List[str]]] = None,
sizes: Optional[Union[int, List[int]]] = None,
replace: bool = False,
):
self.attrs = [attrs] if isinstance(attrs, str) else attrs
self.sizes = sizes
self.replace = replace

def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
for store in data.stores:
attrs = get_attrs_with_suffix(self.attrs, store, '_index')

sizes = self.sizes or ([None] * len(attrs))
if isinstance(sizes, int):
sizes = [self.sizes] * len(attrs)

if len(attrs) != len(sizes):
raise ValueError(
f"The number of attributes (got {len(attrs)}) must match "
f"the number of sizes provided (got {len(sizes)}).")

for attr, size in zip(attrs, sizes):
if attr not in store:
continue
size = get_mask_size(attr, store, size)
mask = index_to_mask(store[attr], size=size)
store[f'{attr[:-6]}_mask'] = mask
if self.replace:
del store[attr]

return data

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(attrs={self.attrs}, '
f'sizes={self.sizes}, replace={self.replace})')


@functional_transform('mask_to_index')
class MaskToIndex(BaseTransform):
r"""Converts a mask to an index representation
(functional name: :obj:`mask_to_index`).

Args:
attrs (str, [str], optional): If given, will only perform mask to index
conversion for the given attributes. If omitted, will infer the
attributes from the suffix :obj:`_mask` (default: :obj:`None`)
replace (bool, optional): if set to :obj:`True` replaces the mask
attributes with index tensors. (default: :obj:`False`)
"""
def __init__(
self,
attrs: Optional[Union[str, List[str]]] = None,
replace: bool = False,
):
self.attrs = [attrs] if isinstance(attrs, str) else attrs
self.replace = replace

def __call__(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
for store in data.stores:
attrs = get_attrs_with_suffix(self.attrs, store, '_mask')

for attr in attrs:
if attr not in store:
continue
index = mask_to_index(store[attr])
store[f'{attr[:-5]}_index'] = index
if self.replace:
del store[attr]

return data

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(attrs={self.attrs}, '
f'replace={self.replace})')