Skip to content

Commit

Permalink
[Torch] Pool ops, convert strides and pool_size to int (#7517)
Browse files Browse the repository at this point in the history
* Convert strides and pool_size to int

* Make helper function, add test

* Fix lint
  • Loading branch information
alexwong authored Feb 26, 2021
1 parent 09b0c8e commit c46b187
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
16 changes: 12 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,11 +825,19 @@ def adaptive_avg_pool_3d(self, inputs, input_types):
output_size = inputs[1]
return _op.nn.adaptive_avg_pool3d(data, output_size=output_size)

@staticmethod
def convert_const_list(data):
if isinstance(data, list):
for i, _ in enumerate(data):
if isinstance(data[i], _expr.Expr):
data[i] = int(_infer_value_simulated(data[i], {}).asnumpy())
return data

def maxpool_2d(self, inputs, input_types):
data = inputs[0]

pool_size = inputs[1]
strides = inputs[2] if inputs[2] else pool_size
pool_size = self.convert_const_list(inputs[1])
strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)
padding = inputs[3]
dilation = inputs[4]
ceil_mode = int(inputs[5])
Expand Down Expand Up @@ -1309,8 +1317,8 @@ def softplus(self, inputs, input_types):
def avg_pool2d(self, inputs, input_types):
data = inputs[0]

pool_size = inputs[1]
strides = inputs[2] if inputs[2] else pool_size
pool_size = self.convert_const_list(inputs[1])
strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)
padding = inputs[3]
ceil_mode = int(inputs[4])
count_include_pad = int(inputs[5])
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,16 @@ def forward(self, *args):
output, indices = self.pool(args[0])
return output

class MaxPool2DWithIntStrides(Module):
def forward(self, *args):
# Makes kernel_size and strides a Relay expr to test converting back to int
x_shape = args[0].shape
kernel_size = [torch.tensor(x_shape[1]).int(), torch.tensor(x_shape[1]).int()]
strides = [torch.tensor(x_shape[0]).int(), torch.tensor(x_shape[0]).int()]
return torch.nn.functional.max_pool2d(args[0], kernel_size=[4, 4], stride=strides)

verify_model(MaxPool2DWithIndices().float().eval(), input_data=input_data)
verify_model(MaxPool2DWithIntStrides().float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
Expand Down

0 comments on commit c46b187

Please sign in to comment.