Skip to content

Commit

Permalink
Fix type parse error about AdaptiveMaxPool (apache#15016)
Browse files Browse the repository at this point in the history
* fix type parse error about max_pool

Fix the bug when the output_size=(3, None).
Crash message: Check failed: (!checked_type.defined()) is false: Expected Array[PrimExpr], but got Array[index 1: relay.Constant]

* add new test case to caputure bug in adaptive_max_pool
  • Loading branch information
jikechao authored and junrushao committed Jun 22, 2023
1 parent d4bfdfd commit 883e3fb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,10 @@ def func(x):
def adaptive_max_pool(self, op, inputs, input_types):
data = inputs[0]
output_size = inputs[1]
for i, item in enumerate(output_size):
if isinstance(item, tvm.relay.expr.Constant):
# convert Constant to int
output_size[i] = item.data.numpy()[()]
# returns dummy indices too
return op(data, output_size=output_size), None

Expand Down
7 changes: 7 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,8 @@ def test_forward_adaptive_avgpool():

input_data = torch.rand([1, 3, 5, 6]).float()
verify_model(torch.nn.AdaptiveAvgPool2d([3, None]).eval(), input_data=input_data)
input_data = torch.rand([1, 1, 3, 5, 6]).float()
verify_model(torch.nn.AdaptiveAvgPool3d([3, None, None]).eval(), input_data=input_data)


@tvm.testing.uses_gpu
Expand All @@ -901,6 +903,11 @@ def test_forward_adaptive_maxpool():
verify_model(torch.nn.AdaptiveMaxPool1d([1]).eval(), input_data=input_data)
verify_model(torch.nn.AdaptiveMaxPool1d([5]).eval(), input_data=input_data)

input_data = torch.rand([1, 3, 5, 6]).float()
verify_model(torch.nn.AdaptiveMaxPool2d([3, None]).eval(), input_data=input_data)
input_data = torch.rand([1, 1, 3, 5, 6]).float()
verify_model(torch.nn.AdaptiveMaxPool3d([3, None, None]).eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_maxpool2d():
Expand Down

0 comments on commit 883e3fb

Please sign in to comment.