Skip to content

Commit

Permalink
fix text on name
Browse files Browse the repository at this point in the history
  • Loading branch information
Padarn committed Jul 29, 2022
1 parent a6e3c85 commit 9a9bd0d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/nn/aggr/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_sort_aggregation_pool():
index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

aggr = SortAggregation(k=5)
assert str(aggr) == 'SortAggr(k=5)'
assert str(aggr) == 'SortAggregation(k=5)'

out = aggr(x, index)
assert out.size() == (2, 5 * 4)
Expand Down Expand Up @@ -37,7 +37,7 @@ def test_sort_aggregation_pool_smaller_than_k():

# Set k which is bigger than both N_1=4 and N_2=6.
aggr = SortAggregation(k=10)
assert str(aggr) == 'SortAggr(k=10)'
assert str(aggr) == 'SortAggregation(k=10)'

out = aggr(x, index)
assert out.size() == (2, 10 * 4)
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_global_sort_pool_dim_size():
index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

aggr = SortAggregation(k=5)
assert str(aggr) == 'SortAggr(k=5)'
assert str(aggr) == 'SortAggregation(k=5)'

# expand batch output by 1
out = aggr(x, index, dim_size=3)
Expand Down

0 comments on commit 9a9bd0d

Please sign in to comment.