Skip to content

Commit

Permalink
[Code Coverage] nn/conv/gen_conv.py (#6703)
Browse files Browse the repository at this point in the history
Code cov improvement for `GENConv`

---------

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 15, 2023
1 parent 7cf7562 commit 844fc10
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 38 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Properly reset the `data_list` cache of an `InMemoryDataset` when accessing `dataset.data` ([#6685](https://github.com/pyg-team/pytorch_geometric/pull/6685))
- Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613))
- Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
Expand Down
72 changes: 35 additions & 37 deletions test/nn/conv/test_gen_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from torch_geometric.testing import is_full_test


@pytest.mark.parametrize('aggr', ['softmax', 'powermean'])
@pytest.mark.parametrize('aggr', [
'softmax',
'powermean',
['softmax', 'powermean'],
])
def test_gen_conv(aggr):
x1 = torch.randn(4, 16)
x2 = torch.randn(2, 16)
Expand All @@ -16,64 +20,59 @@ def test_gen_conv(aggr):
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

conv = GENConv(16, 32, aggr, edge_dim=16)
assert conv.__repr__() == f'GENConv(16, 32, aggr={aggr})'
conv = GENConv(16, 32, aggr, edge_dim=16, msg_norm=True)
assert str(conv) == f'GENConv(16, 32, aggr={aggr})'
out11 = conv(x1, edge_index)
assert out11.size() == (4, 32)
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out11, atol=1e-6)
assert torch.allclose(conv(x1, adj1.t()), out11, atol=1e-6)
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out11)
assert torch.allclose(conv(x1, adj1.t()), out11)

out12 = conv(x1, edge_index, value)
assert out12.size() == (4, 32)
assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out12,
atol=1e-6)
assert torch.allclose(conv(x1, adj2.t()), out12, atol=1e-6)
assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out12)
assert torch.allclose(conv(x1, adj2.t()), out12)

if is_full_test():
t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x1, edge_index), out11, atol=1e-6)
assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out11,
atol=1e-6)
assert torch.allclose(jit(x1, edge_index, value), out12, atol=1e-6)
assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out12,
atol=1e-6)
assert torch.allclose(jit(x1, edge_index), out11)
assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out11)
assert torch.allclose(jit(x1, edge_index, value), out12)
assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out12)

t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x1, adj1.t()), out11, atol=1e-6)
assert torch.allclose(jit(x1, adj2.t()), out12, atol=1e-6)
assert torch.allclose(jit(x1, adj1.t()), out11)
assert torch.allclose(jit(x1, adj2.t()), out12)

adj1 = adj1.sparse_resize((4, 2))
adj2 = adj2.sparse_resize((4, 2))

out21 = conv((x1, x2), edge_index)
assert out21.size() == (2, 32)
assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out21,
atol=1e-6)
assert torch.allclose(conv((x1, x2), adj1.t()), out21, atol=1e-6)
assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out21)
assert torch.allclose(conv((x1, x2), adj1.t()), out21)

out22 = conv((x1, x2), edge_index, value)
assert out22.size() == (2, 32)
assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out22,
atol=1e-6)
assert torch.allclose(conv((x1, x2), adj2.t()), out22, atol=1e-6)
assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out22)
assert torch.allclose(conv((x1, x2), adj2.t()), out22)

if is_full_test():
t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit((x1, x2), edge_index), out21, atol=1e-6)
assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out21,
atol=1e-6)
assert torch.allclose(jit((x1, x2), edge_index, value), out22,
atol=1e-6)
assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out22,
atol=1e-6)
assert torch.allclose(jit((x1, x2), edge_index), out21)
assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out21)
assert torch.allclose(jit((x1, x2), edge_index, value), out22)
assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out22)

t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit((x1, x2), adj1.t()), out21, atol=1e-6)
assert torch.allclose(jit((x1, x2), adj2.t()), out22, atol=1e-6)
assert torch.allclose(jit((x1, x2), adj1.t()), out21)
assert torch.allclose(jit((x1, x2), adj2.t()), out22)

conv.reset_parameters()
assert float(conv.msg_norm.scale) == 1

x1 = torch.randn(4, 8)
x2 = torch.randn(2, 16)
Expand Down Expand Up @@ -104,14 +103,13 @@ def test_gen_conv(aggr):
if is_full_test():
t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit((x1, x2), edge_index, value), out1,
atol=1e-6)
assert torch.allclose(jit((x1, x2), edge_index, value), out1)
assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)),
out1, atol=1e-6)
out1)
assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)),
out2, atol=1e-6)
out2)

t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)
assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6)
assert torch.allclose(jit((x1, x2), adj.t()), out1)
assert torch.allclose(jit((x1, None), adj.t()), out2)

0 comments on commit 844fc10

Please sign in to comment.