From b6b6b7ec7bc94b12da24e39fe9cadb051da2f252 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Wed, 7 Sep 2022 20:23:14 +0100 Subject: [PATCH] split test files --- .../{test_mask.py => test_index_to_mask.py} | 21 +---------------- test/transforms/test_mask_to_index.py | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 20 deletions(-) rename test/transforms/{test_mask.py => test_index_to_mask.py} (55%) create mode 100644 test/transforms/test_mask_to_index.py diff --git a/test/transforms/test_mask.py b/test/transforms/test_index_to_mask.py similarity index 55% rename from test/transforms/test_mask.py rename to test/transforms/test_index_to_mask.py index ff50ca6b68ee0..9e87641ca650b 100644 --- a/test/transforms/test_mask.py +++ b/test/transforms/test_index_to_mask.py @@ -2,7 +2,7 @@ import torch from torch_geometric.data import Data -from torch_geometric.transforms import IndexToMask, MaskToIndex +from torch_geometric.transforms import IndexToMask def test_index_to_mask(): @@ -30,22 +30,3 @@ def test_index_to_mask(): with pytest.raises(ValueError): IndexToMask(sizes=(1, 2))(data) - - -def test_mask_to_index(): - assert MaskToIndex().__repr__() == 'MaskToIndex(attrs=None)' - - train_mask = torch.tensor([1, 0, 1, 1, 0, 0, 0]).bool() - val_mask = torch.tensor([0, 1, 0, 0, 1, 0, 0]).bool() - test_mask = torch.tensor([0, 0, 0, 0, 0, 1, 1]).bool() - data = Data(train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) - - out = MaskToIndex()(data.clone()) - assert len(out) == 6 - assert hasattr(out, "train_indices") - assert hasattr(out, "val_indices") - assert hasattr(out, "test_indices") - - out = MaskToIndex(attrs="train_mask")(data.clone()) - assert len(out) == 4 - assert out.train_indices.numel() == 3 diff --git a/test/transforms/test_mask_to_index.py b/test/transforms/test_mask_to_index.py new file mode 100644 index 0000000000000..3d6f0756bd211 --- /dev/null +++ b/test/transforms/test_mask_to_index.py @@ -0,0 +1,23 @@ +import torch + +from torch_geometric.data import Data +from torch_geometric.transforms import MaskToIndex + + +def test_mask_to_index(): + assert MaskToIndex().__repr__() == 'MaskToIndex(attrs=None)' + + train_mask = torch.tensor([1, 0, 1, 1, 0, 0, 0]).bool() + val_mask = torch.tensor([0, 1, 0, 0, 1, 0, 0]).bool() + test_mask = torch.tensor([0, 0, 0, 0, 0, 1, 1]).bool() + data = Data(train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) + + out = MaskToIndex()(data.clone()) + assert len(out) == 6 + assert hasattr(out, "train_indices") + assert hasattr(out, "val_indices") + assert hasattr(out, "test_indices") + + out = MaskToIndex(attrs="train_mask")(data.clone()) + assert len(out) == 4 + assert out.train_indices.numel() == 3