Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Add support for max_pool1d #5142

Merged
merged 4 commits into from
Mar 24, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,33 @@ def _impl(inputs, input_types):
pool_size = _infer_shape(inputs[1])
strides = _infer_shape(inputs[2])
padding = _infer_shape(inputs[3])

dilation = _infer_shape(inputs[4])
Copy link
Member

@masahi masahi Mar 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does pooling have dilation argument?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In https://pytorch.org/docs/stable/nn.html#maxpool1d, MaxPool1d and MaxPool2d have dilation argument.

ceil_mode = int(inputs[5])

if dilation != (1, 1):
msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation), )
raise NotImplementedError(msg)

return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode)
return _impl

def _maxpool_1d():
def _impl(inputs, input_types):
data = inputs[0]

pool_size = _infer_shape(inputs[1])
strides = _infer_shape(inputs[2])
padding = _infer_shape(inputs[3])
dilation = _infer_shape(inputs[4])
ceil_mode = int(inputs[5])

if dilation != (1,):
msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation), )
raise NotImplementedError(msg)

return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode)
return _impl

def _hardtanh():
def _impl(inputs, input_types):
a = inputs[0]
Expand Down Expand Up @@ -863,6 +884,7 @@ def _wrap_const(c):
"aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(),
"aten::max_pool2d" : _maxpool_2d(),
"aten::max_pool2d_with_indices" : _maxpool_2d(),
"aten::max_pool1d" : _maxpool_1d(),
"aten::hardtanh" : _hardtanh(),
"aten::hardtanh_" : _hardtanh(),
"aten::_convolution" : _convolution(),
Expand Down
4 changes: 4 additions & 0 deletions src/relay/op/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,10 @@ Array<te::Tensor> Pool1DCompute(const Attrs& attrs,
<< " or 4-D input (e.g. NCWc on for vector instructions)"
<< " or 5-D input (e.g. NCWnc for tensor accelerators)";

if (param->padding.size() == 1) {
padding.push_back(padding[0]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this? Does 1D pooling require two pad values (left & right)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

if (mode == topi::nn::kAvgPool) {
bool count_include_pad = reinterpret_cast<const AvgPool1DAttrs*>(param)->count_include_pad;
return Array<te::Tensor>{
Expand Down
31 changes: 29 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def forward(self, *args):
verify_model(AdaptiveAvgPool2D1().float().eval(), input_data=input_data)
verify_model(AdaptiveAvgPool2D2().float().eval(), input_data=input_data)

def test_forward_maxpool():
def test_forward_maxpool2d():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

Expand All @@ -363,9 +363,35 @@ class MaxPool2D2(Module):
def forward(self, *args):
return torch.nn.MaxPool2d(kernel_size=[10, 10])(args[0])

class MaxPool2D3(Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this wrapper

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need a test for padding and stride in Maxpool?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am talking about the same point as https://github.com/apache/incubator-tvm/pull/5142/files#r397046667

Yes, you should have a test, but no need to write a wrapper class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am talking about the same point as https://github.com/apache/incubator-tvm/pull/5142/files#r397046667

Yes, you should have a test, but no need to write a wrapper class

Thanks. I misunderstood your review suggestion :-)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

def forward(self, *args):
return torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)(args[0])

input_data = torch.rand(input_shape).float()
verify_model(MaxPool2D1().float().eval(), input_data=input_data)
verify_model(MaxPool2D2().float().eval(), input_data=input_data)
verify_model(MaxPool2D3().float().eval(), input_data=input_data)

def test_forward_maxpool1d():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10]

class MaxPool1D1(Module):
def forward(self, *args):
return torch.nn.MaxPool1d(kernel_size=1)(args[0])

class MaxPool1D2(Module):
def forward(self, *args):
return torch.nn.MaxPool1d(kernel_size=10)(args[0])

class MaxPool1D3(Module):
def forward(self, *args):
return torch.nn.MaxPool1d(kernel_size=4, padding=2, stride=2)(args[0])

input_data = torch.rand(input_shape).float()
verify_model(MaxPool1D1().float().eval(), input_data=input_data)
verify_model(MaxPool1D2().float().eval(), input_data=input_data)
verify_model(MaxPool1D3().float().eval(), input_data=input_data)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.


def test_forward_avgpool():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -1028,7 +1054,8 @@ def forward(self, xs):
test_forward_concatenate()
test_forward_relu()
test_forward_adaptiveavgpool()
test_forward_maxpool()
test_forward_maxpool2d()
test_forward_maxpool1d()
test_forward_hardtanh()
test_forward_conv()
test_forward_threshold()
Expand Down