[Relay] Fix an adaptive_max_pool1d operator conversion bug (#15386)
* Fix an adaptive_max_pool1d operator conversion bug

* Add tests for the adaptive_max_pool1d conversion fix

* Add a TODO
haoyang9804 authored Sep 4, 2023
1 parent f9e6018 commit d75083c
Showing 2 changed files with 19 additions and 1 deletion.
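For context, the bug shows up when converting a PyTorch model that uses adaptive_max_pool1d (see https://github.com/apache/tvm/issues/15004, referenced in the TODO below): the unused second output of aten::adaptive_max_pool1d maps to None in the frontend's outputs table. A minimal reproduction sketch follows; it assumes a TVM build with the PyTorch frontend enabled, and the input name "input0" is illustrative rather than taken from the issue.

import torch
from tvm import relay

# Trace a model whose only op is adaptive_max_pool1d, then convert it to Relay.
model = torch.nn.AdaptiveMaxPool1d(3).eval()
example = torch.randn(2, 2, 4)
scripted = torch.jit.trace(model, example)

# Before this commit, the conversion could fail while wrapping the graph's
# unused second output (the pooling indices) of aten::adaptive_max_pool1d.
mod, params = relay.frontend.from_pytorch(scripted, [("input0", (2, 2, 4))])
print(mod)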
10 changes: 9 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
@@ -4291,7 +4291,15 @@ def _handel_nested_input(inputs):

            self.current_op.pop()

-        return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]
+        # TODO(@haoyang9804): outputs[ret_name] could be None and cause some issue
+        # revealed by https://github.com/apache/tvm/issues/15004
+        # Now only adaptive_max_pool1d is considered. Maybe other ops could also
+        # trigger this problem.
+        return [
+            _wrap_const(outputs[ret_name])
+            for ret_name in ret_names
+            if ret_name != "aten::adaptive_max_pool1d_0_1"
+        ]

    def _set_parameter_source_name(self, op_node, outputs):
        """A helper function to rewrite source_name of parameter."""
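A note on the filtered name (my reading of the change, not part of the commit message): in TorchScript, aten::adaptive_max_pool1d has two outputs, the pooled values and the argmax indices, and the trailing index in "aten::adaptive_max_pool1d_0_1" appears to refer to that second output. When a model only consumes the values, the indices entry in outputs can be None, which is what the list comprehension now skips. The two-output behaviour can be seen directly in PyTorch:

import torch
import torch.nn.functional as F

x = torch.randn(2, 2, 4)
# With return_indices=True, adaptive_max_pool1d returns (values, indices);
# the indices tensor corresponds to the extra graph output skipped above.
values, indices = F.adaptive_max_pool1d(x, 3, return_indices=True)
print(values.shape, indices.shape)  # both torch.Size([2, 2, 3])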
10 changes: 10 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -3439,6 +3439,16 @@ def forward(self, *args):
    verify_model(Full2().float().eval(), input_data=[])


+@tvm.testing.uses_gpu
+def test_forward_adaptive_max_pool1d():
+    """test_forward_adaptive_max_pool1d"""
+    torch.set_grad_enabled(False)
+    input_data = [torch.randn([2, 2, 4], dtype=torch.float32)]
+    m = torch.nn.AdaptiveMaxPool1d(3)
+
+    verify_model(m.float().eval(), input_data=input_data)
+
+
@tvm.testing.uses_gpu
def test_forward_full_like():
    """test_forward_full_like"""
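To exercise the new coverage locally (assuming a TVM build with the PyTorch frontend enabled), the test can be run on its own with pytest, e.g.:

pytest tests/python/frontend/pytorch/test_forward.py -k test_forward_adaptive_max_pool1d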
