Skip to content

Commit

Permalink
split test files
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Sep 7, 2022
1 parent fbb7fda commit b6b6b7e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import IndexToMask, MaskToIndex
from torch_geometric.transforms import IndexToMask


def test_index_to_mask():
Expand Down Expand Up @@ -30,22 +30,3 @@ def test_index_to_mask():

with pytest.raises(ValueError):
IndexToMask(sizes=(1, 2))(data)


def test_mask_to_index():
assert MaskToIndex().__repr__() == 'MaskToIndex(attrs=None)'

train_mask = torch.tensor([1, 0, 1, 1, 0, 0, 0]).bool()
val_mask = torch.tensor([0, 1, 0, 0, 1, 0, 0]).bool()
test_mask = torch.tensor([0, 0, 0, 0, 0, 1, 1]).bool()
data = Data(train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

out = MaskToIndex()(data.clone())
assert len(out) == 6
assert hasattr(out, "train_indices")
assert hasattr(out, "val_indices")
assert hasattr(out, "test_indices")

out = MaskToIndex(attrs="train_mask")(data.clone())
assert len(out) == 4
assert out.train_indices.numel() == 3
23 changes: 23 additions & 0 deletions test/transforms/test_mask_to_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import MaskToIndex


def test_mask_to_index():
assert MaskToIndex().__repr__() == 'MaskToIndex(attrs=None)'

train_mask = torch.tensor([1, 0, 1, 1, 0, 0, 0]).bool()
val_mask = torch.tensor([0, 1, 0, 0, 1, 0, 0]).bool()
test_mask = torch.tensor([0, 0, 0, 0, 0, 1, 1]).bool()
data = Data(train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

out = MaskToIndex()(data.clone())
assert len(out) == 6
assert hasattr(out, "train_indices")
assert hasattr(out, "val_indices")
assert hasattr(out, "test_indices")

out = MaskToIndex(attrs="train_mask")(data.clone())
assert len(out) == 4
assert out.train_indices.numel() == 3

0 comments on commit b6b6b7e

Please sign in to comment.