From 264098c048a831a3944c8d67ca4721f5d42b5f9d Mon Sep 17 00:00:00 2001
From: wyc-ruiker
Date: Tue, 24 Mar 2020 17:28:50 +0000
Subject: [PATCH 1/4] [Torch] Add support for max_pool1d

---
 python/tvm/relay/frontend/pytorch.py          | 14 ++++++++++++++
 src/relay/op/nn/pooling.cc                    |  4 ++++
 tests/python/frontend/pytorch/test_forward.py | 21 +++++++++++++++++++--
 3 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 83436f25ca85..9d273d123cfd 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -219,6 +219,19 @@ def _impl(inputs, input_types):
         return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode)
     return _impl
 
+def _maxpool_1d():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+
+        pool_size = _infer_shape(inputs[1])
+        strides = _infer_shape(inputs[2])
+        padding = _infer_shape(inputs[3])
+
+        ceil_mode = int(inputs[5])
+
+        return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode)
+    return _impl
+
 def _hardtanh():
     def _impl(inputs, input_types):
         a = inputs[0]
@@ -863,6 +876,7 @@ def _wrap_const(c):
     "aten::adaptive_max_pool2d"             : _adaptive_max_pool_2d(),
     "aten::max_pool2d"                      : _maxpool_2d(),
     "aten::max_pool2d_with_indices"         : _maxpool_2d(),
+    "aten::max_pool1d"                      : _maxpool_1d(),
     "aten::hardtanh"                        : _hardtanh(),
     "aten::hardtanh_"                       : _hardtanh(),
     "aten::_convolution"                    : _convolution(),
diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc
index 6a2a59b91be0..c20793d9ac28 100644
--- a/src/relay/op/nn/pooling.cc
+++ b/src/relay/op/nn/pooling.cc
@@ -987,6 +987,10 @@ Array<te::Tensor> Pool1DCompute(const Attrs& attrs,
       << " or 4-D input (e.g. NCWc on for vector instructions)"
       << " or 5-D input (e.g. NCWnc for tensor accelerators)";
 
+  if (param->padding.size() == 1) {
+    padding.push_back(padding[0]);
+  }
+
   if (mode == topi::nn::kAvgPool) {
     bool count_include_pad = reinterpret_cast<const AvgPool1DAttrs*>(param)->count_include_pad;
     return Array<te::Tensor>{
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index e74832a78a72..fa1b46ba8f5d 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -351,7 +351,7 @@ def forward(self, *args):
     verify_model(AdaptiveAvgPool2D1().float().eval(), input_data=input_data)
     verify_model(AdaptiveAvgPool2D2().float().eval(), input_data=input_data)
 
-def test_forward_maxpool():
+def test_forward_maxpool2d():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
 
@@ -367,6 +367,22 @@ def forward(self, *args):
     verify_model(MaxPool2D1().float().eval(), input_data=input_data)
     verify_model(MaxPool2D2().float().eval(), input_data=input_data)
 
+def test_forward_maxpool1d():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10]
+
+    class MaxPool1D1(Module):
+        def forward(self, *args):
+            return torch.nn.MaxPool1d(kernel_size=1)(args[0])
+
+    class MaxPool1D2(Module):
+        def forward(self, *args):
+            return torch.nn.MaxPool1d(kernel_size=10)(args[0])
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(MaxPool1D1().float().eval(), input_data=input_data)
+    verify_model(MaxPool1D2().float().eval(), input_data=input_data)
+
 def test_forward_avgpool():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1028,7 +1044,8 @@ def forward(self, xs):
     test_forward_concatenate()
     test_forward_relu()
     test_forward_adaptiveavgpool()
-    test_forward_maxpool()
+    test_forward_maxpool2d()
+    test_forward_maxpool1d()
     test_forward_hardtanh()
     test_forward_conv()
     test_forward_threshold()

From 6d5ea20897c74e2c2f65eeb3145f50ca3697eff0 Mon Sep 17 00:00:00 2001
From: wyc-ruiker
Date: Tue, 24 Mar 2020 17:55:06 +0000
Subject: [PATCH 2/4] add test

---
 python/tvm/relay/frontend/pytorch.py          | 10 ++++++++--
 tests/python/frontend/pytorch/test_forward.py | 10 ++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9d273d123cfd..fb75bb97de77 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -213,9 +213,12 @@ def _impl(inputs, input_types):
         pool_size = _infer_shape(inputs[1])
         strides = _infer_shape(inputs[2])
         padding = _infer_shape(inputs[3])
-
+        dilation = _infer_shape(inputs[4])
         ceil_mode = int(inputs[5])
 
+        if dilation != (1, 1):
+            raise NotImplementedError("MaxPool2d with dilation %s is not implemented" % (str(dilation), ))
+
         return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode)
     return _impl
 
@@ -226,9 +229,12 @@ def _impl(inputs, input_types):
         pool_size = _infer_shape(inputs[1])
         strides = _infer_shape(inputs[2])
         padding = _infer_shape(inputs[3])
-
+        dilation = _infer_shape(inputs[4])
         ceil_mode = int(inputs[5])
 
+        if dilation != (1,):
+            raise NotImplementedError("MaxPool1d with dilation %s is not implemented" % (str(dilation), ))
+
         return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode)
     return _impl
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index fa1b46ba8f5d..9dd63f6f8f0d 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -363,9 +363,14 @@ class MaxPool2D2(Module):
         def forward(self, *args):
             return torch.nn.MaxPool2d(kernel_size=[10, 10])(args[0])
 
+    class MaxPool2D3(Module):
+        def forward(self, *args):
+            return torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)(args[0])
+
     input_data = torch.rand(input_shape).float()
     verify_model(MaxPool2D1().float().eval(), input_data=input_data)
     verify_model(MaxPool2D2().float().eval(), input_data=input_data)
+    verify_model(MaxPool2D3().float().eval(), input_data=input_data)
 
 def test_forward_maxpool1d():
     torch.set_grad_enabled(False)
@@ -379,9 +384,14 @@ class MaxPool1D2(Module):
         def forward(self, *args):
             return torch.nn.MaxPool1d(kernel_size=10)(args[0])
 
+    class MaxPool1D3(Module):
+        def forward(self, *args):
+            return torch.nn.MaxPool1d(kernel_size=4, padding=2, stride=2)(args[0])
+
     input_data = torch.rand(input_shape).float()
     verify_model(MaxPool1D1().float().eval(), input_data=input_data)
     verify_model(MaxPool1D2().float().eval(), input_data=input_data)
+    verify_model(MaxPool1D3().float().eval(), input_data=input_data)
 
 def test_forward_avgpool():
     torch.set_grad_enabled(False)

From 628c1e7b81d6ca8f8cb20eb3a569c00adab637ef Mon Sep 17 00:00:00 2001
From: wyc-ruiker
Date: Tue, 24 Mar 2020 18:03:21 +0000
Subject: [PATCH 3/4] fix line-too-long

---
 python/tvm/relay/frontend/pytorch.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index fb75bb97de77..455c99726f27 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -217,7 +217,8 @@ def _impl(inputs, input_types):
         ceil_mode = int(inputs[5])
 
         if dilation != (1, 1):
-            raise NotImplementedError("MaxPool2d with dilation %s is not implemented" % (str(dilation), ))
+            msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation), )
+            raise NotImplementedError(msg)
 
         return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode)
     return _impl
@@ -233,7 +234,8 @@ def _impl(inputs, input_types):
         ceil_mode = int(inputs[5])
 
         if dilation != (1,):
-            raise NotImplementedError("MaxPool1d with dilation %s is not implemented" % (str(dilation), ))
+            msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation), )
+            raise NotImplementedError(msg)
 
         return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode)
     return _impl

From 4f90a2d2828d71b001929a26fcf1b42fcde5d157 Mon Sep 17 00:00:00 2001
From: wyc-ruiker
Date: Tue, 24 Mar 2020 21:02:20 +0000
Subject: [PATCH 4/4] remove wrapper class

---
 tests/python/frontend/pytorch/test_forward.py | 50 ++++++++++++++++++--------------------------------
 1 file changed, 18 insertions(+), 32 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9dd63f6f8f0d..1aaaf20774b2 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -354,44 +354,30 @@ def forward(self, *args):
 def test_forward_maxpool2d():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
-
-    class MaxPool2D1(Module):
-        def forward(self, *args):
-            return torch.nn.MaxPool2d(kernel_size=[1, 1])(args[0])
-
-    class MaxPool2D2(Module):
-        def forward(self, *args):
-            return torch.nn.MaxPool2d(kernel_size=[10, 10])(args[0])
-
-    class MaxPool2D3(Module):
-        def forward(self, *args):
-            return torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)(args[0])
-
     input_data = torch.rand(input_shape).float()
-    verify_model(MaxPool2D1().float().eval(), input_data=input_data)
-    verify_model(MaxPool2D2().float().eval(), input_data=input_data)
-    verify_model(MaxPool2D3().float().eval(), input_data=input_data)
+
+    verify_model(torch.nn.MaxPool2d(kernel_size=[1, 1]).eval(),
+                 input_data)
+    verify_model(torch.nn.MaxPool2d(kernel_size=[10, 10]).eval(),
+                 input_data)
+    verify_model(torch.nn.MaxPool2d(kernel_size=[4, 4],
+                                    padding=2,
+                                    stride=2).eval(),
+                 input_data)
 
 def test_forward_maxpool1d():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10]
-
-    class MaxPool1D1(Module):
-        def forward(self, *args):
-            return torch.nn.MaxPool1d(kernel_size=1)(args[0])
-
-    class MaxPool1D2(Module):
-        def forward(self, *args):
-            return torch.nn.MaxPool1d(kernel_size=10)(args[0])
-
-    class MaxPool1D3(Module):
-        def forward(self, *args):
-            return torch.nn.MaxPool1d(kernel_size=4, padding=2, stride=2)(args[0])
-
     input_data = torch.rand(input_shape).float()
-    verify_model(MaxPool1D1().float().eval(), input_data=input_data)
-    verify_model(MaxPool1D2().float().eval(), input_data=input_data)
-    verify_model(MaxPool1D3().float().eval(), input_data=input_data)
+
+    verify_model(torch.nn.MaxPool1d(kernel_size=1).eval(),
+                 input_data)
+    verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(),
+                 input_data)
+    verify_model(torch.nn.MaxPool1d(kernel_size=4,
+                                    padding=2,
+                                    stride=2).eval(),
+                 input_data)
 
 def test_forward_avgpool():
     torch.set_grad_enabled(False)
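
Usage sketch (not part of the patches above): a minimal way to exercise the new
aten::max_pool1d conversion end to end, assuming a TVM build that includes this
change. The input name "input0" and the (name, shape) list are illustrative; they
follow the convention used by verify_model in test_forward.py. Note that, per the
check added in PATCH 2/4, MaxPool1d with a dilation other than 1 raises
NotImplementedError.

    import torch
    from tvm import relay

    # Trace a 1-D max-pooling module; the resulting TorchScript graph contains aten::max_pool1d.
    model = torch.nn.MaxPool1d(kernel_size=4, stride=2, padding=2).eval()
    example_input = torch.rand(1, 3, 10)
    scripted = torch.jit.trace(model, example_input)

    # Convert to Relay; the _maxpool_1d() converter maps the op to nn.max_pool1d with NCW layout.
    shape_list = [("input0", (1, 3, 10))]
    mod, params = relay.frontend.from_pytorch(scripted, shape_list)
    print(mod)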