From c46b1876f5bba743c864d3fbeeb3d8df4664aa5f Mon Sep 17 00:00:00 2001 From: Alex Wong <11878166+alexwong@users.noreply.github.com> Date: Thu, 25 Feb 2021 21:37:11 -0800 Subject: [PATCH] [Torch] Pool ops, convert strides and pool_size to int (#7517) * Convert strides and pool_size to int * Make helper function, add test * Fix lint --- python/tvm/relay/frontend/pytorch.py | 16 ++++++++++++---- tests/python/frontend/pytorch/test_forward.py | 9 +++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 679541051e75..fdebd2f50e68 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -825,11 +825,19 @@ def adaptive_avg_pool_3d(self, inputs, input_types): output_size = inputs[1] return _op.nn.adaptive_avg_pool3d(data, output_size=output_size) + @staticmethod + def convert_const_list(data): + if isinstance(data, list): + for i, _ in enumerate(data): + if isinstance(data[i], _expr.Expr): + data[i] = int(_infer_value_simulated(data[i], {}).asnumpy()) + return data + def maxpool_2d(self, inputs, input_types): data = inputs[0] - pool_size = inputs[1] - strides = inputs[2] if inputs[2] else pool_size + pool_size = self.convert_const_list(inputs[1]) + strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size) padding = inputs[3] dilation = inputs[4] ceil_mode = int(inputs[5]) @@ -1309,8 +1317,8 @@ def softplus(self, inputs, input_types): def avg_pool2d(self, inputs, input_types): data = inputs[0] - pool_size = inputs[1] - strides = inputs[2] if inputs[2] else pool_size + pool_size = self.convert_const_list(inputs[1]) + strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size) padding = inputs[3] ceil_mode = int(inputs[4]) count_include_pad = int(inputs[5]) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 0cf4839c6ebb..90604751d4f1 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -736,7 +736,16 @@ def forward(self, *args): output, indices = self.pool(args[0]) return output + class MaxPool2DWithIntStrides(Module): + def forward(self, *args): + # Makes kernel_size and strides a Relay expr to test converting back to int + x_shape = args[0].shape + kernel_size = [torch.tensor(x_shape[1]).int(), torch.tensor(x_shape[1]).int()] + strides = [torch.tensor(x_shape[0]).int(), torch.tensor(x_shape[0]).int()] + return torch.nn.functional.max_pool2d(args[0], kernel_size=[4, 4], stride=strides) + verify_model(MaxPool2DWithIndices().float().eval(), input_data=input_data) + verify_model(MaxPool2DWithIntStrides().float().eval(), input_data=input_data) @tvm.testing.uses_gpu