Skip to content

Commit

Permalink
Check num_params in to_hetero transformers (#5185)
Browse files Browse the repository at this point in the history
* add test

* typo
  • Loading branch information
rusty1s authored Aug 10, 2022
1 parent c7ef923 commit 119318f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 2 additions & 0 deletions test/nn/test_to_hetero_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def test_to_hetero():
assert out[1][('paper', 'cites', 'paper')].size() == (200, 16)
assert out[1][('paper', 'written_by', 'author')].size() == (200, 16)
assert out[1][('author', 'writes', 'paper')].size() == (200, 16)
assert sum(p.numel() for p in model.parameters()) == 1520

for aggr in ['sum', 'mean', 'min', 'max', 'mul']:
model = Net2()
Expand All @@ -184,6 +185,7 @@ def test_to_hetero():
assert isinstance(out, dict) and len(out) == 2
assert out['paper'].size() == (100, 32)
assert out['author'].size() == (100, 32)
assert sum(p.numel() for p in model.parameters()) == 5824

model = Net3()
model = to_hetero(model, metadata, debug=False)
Expand Down
2 changes: 2 additions & 0 deletions test/nn/test_to_hetero_with_bases_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def test_to_hetero_with_bases():
assert out[1][('paper', 'cites', 'paper')].size() == (200, 16)
assert out[1][('paper', 'written_by', 'author')].size() == (200, 16)
assert out[1][('author', 'writes', 'paper')].size() == (200, 16)
assert sum(p.numel() for p in model.parameters()) == 1264

model = Net2()
in_channels = {'x': 16}
Expand All @@ -149,6 +150,7 @@ def test_to_hetero_with_bases():
assert isinstance(out, dict) and len(out) == 2
assert out['paper'].size() == (100, 32)
assert out['author'].size() == (100, 32)
assert sum(p.numel() for p in model.parameters()) == 6076

model = Net3()
in_channels = {'x': 16, 'edge_attr': 8}
Expand Down
2 changes: 0 additions & 2 deletions torch_geometric/nn/to_hetero_with_bases_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,6 @@ def get_node_offset_dict(
input_dict: Dict[NodeType, Union[Tensor, SparseTensor]],
type2id: Dict[NodeType, int],
) -> Dict[NodeType, int]:

cumsum = 0
out: Dict[NodeType, int] = {}
for key in type2id.keys():
Expand All @@ -415,7 +414,6 @@ def get_edge_offset_dict(
input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]],
type2id: Dict[EdgeType, int],
) -> Dict[EdgeType, int]:

cumsum = 0
out: Dict[EdgeType, int] = {}
for key in type2id.keys():
Expand Down

0 comments on commit 119318f

Please sign in to comment.