From e531e47b2550387be03c3deb89f62a89b43b1197 Mon Sep 17 00:00:00 2001 From: Paul Louis Date: Sun, 12 Feb 2023 17:03:28 -0500 Subject: [PATCH 1/3] Add coverage for sign transforms --- test/transforms/test_sign.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 test/transforms/test_sign.py diff --git a/test/transforms/test_sign.py b/test/transforms/test_sign.py new file mode 100644 index 000000000000..3be4e0643892 --- /dev/null +++ b/test/transforms/test_sign.py @@ -0,0 +1,27 @@ +import copy + +import torch + +from torch_geometric.data import Data +from torch_geometric.transforms import SIGN + + +def test_sign_transform(): + assert SIGN(K=1).__repr__() == 'SIGN(K=1)' + edge_index = torch.tensor([ + [0, 1, 2, 3, 3, 4], + [1, 0, 3, 2, 4, 3], + ]) + x = torch.ones(5, 3) + expected_x1 = torch.tensor([[1, 1, 1], [1, 1, 1], [0.7071, 0.7071, 0.7071], + [1.4142, 1.4142, 1.4142], + [0.7071, 0.7071, 0.7071]]) + expected_x2 = torch.ones(5, 3) + data = Data(x=x, edge_index=edge_index, num_nodes=5) + transform = SIGN(K=2) + transformed_data = transform(copy.copy(data)) + assert len(transformed_data) == 5 + assert transformed_data.edge_index.tolist() == edge_index.tolist() + assert transformed_data.x.tolist() == x.tolist() + assert torch.allclose(transformed_data.x1, expected_x1, atol=1e-4) + assert torch.allclose(transformed_data.x2, expected_x2, atol=1e-4) From 5ae09218e979f576c0c423189f71490f73564686 Mon Sep 17 00:00:00 2001 From: Paul Louis Date: Sun, 12 Feb 2023 17:06:42 -0500 Subject: [PATCH 2/3] Update Changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 118daf6618d4..e03dcc4a6d1b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,7 +82,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613) - Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609)) -- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681)) +- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683)) - Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522)) - Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517)) - Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512)) From 1408d9f8b9cec54905c56d31888134b4967e3d40 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 13 Feb 2023 07:37:52 +0100 Subject: [PATCH 3/3] update --- test/transforms/test_sign.py | 37 ++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/test/transforms/test_sign.py b/test/transforms/test_sign.py index 3be4e0643892..94c371938238 100644 --- a/test/transforms/test_sign.py +++ b/test/transforms/test_sign.py @@ -1,27 +1,32 @@ -import copy - import torch from torch_geometric.data import Data from torch_geometric.transforms import SIGN -def test_sign_transform(): - assert SIGN(K=1).__repr__() == 'SIGN(K=1)' +def test_sign(): + x = torch.ones(5, 3) edge_index = torch.tensor([ [0, 1, 2, 3, 3, 4], [1, 0, 3, 2, 4, 3], ]) - x = torch.ones(5, 3) - expected_x1 = torch.tensor([[1, 1, 1], [1, 1, 1], [0.7071, 0.7071, 0.7071], - [1.4142, 1.4142, 1.4142], - [0.7071, 0.7071, 0.7071]]) - expected_x2 = torch.ones(5, 3) - data = Data(x=x, edge_index=edge_index, num_nodes=5) + data = Data(x=x, edge_index=edge_index) + transform = SIGN(K=2) - transformed_data = transform(copy.copy(data)) - assert len(transformed_data) == 5 - assert transformed_data.edge_index.tolist() == edge_index.tolist() - assert transformed_data.x.tolist() == x.tolist() - assert torch.allclose(transformed_data.x1, expected_x1, atol=1e-4) - assert torch.allclose(transformed_data.x2, expected_x2, atol=1e-4) + assert str(transform) == 'SIGN(K=2)' + + expected_x1 = torch.tensor([ + [1, 1, 1], + [1, 1, 1], + [0.7071, 0.7071, 0.7071], + [1.4142, 1.4142, 1.4142], + [0.7071, 0.7071, 0.7071], + ]) + expected_x2 = torch.ones(5, 3) + + out = transform(data) + assert len(out) == 4 + assert torch.equal(out.edge_index, edge_index) + assert torch.allclose(out.x, x) + assert torch.allclose(out.x1, expected_x1, atol=1e-4) + assert torch.allclose(out.x2, expected_x2)