[PT FE] Support default strides for avg and max pooling (#17337)
* Support default strides for avg and max pooling

* Fix code style

* Remove changes from other ticket
mvafin authored May 4, 2023
1 parent ec90869 commit c1933fc
Showing 3 changed files with 49 additions and 26 deletions.
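The motivation is PyTorch's documented pooling semantics: when the stride argument is omitted, it defaults to the kernel size. A minimal eager-mode illustration of that behavior (not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
# Omitting stride makes PyTorch fall back to kernel_size,
# so the two calls below are equivalent for both pooling flavors.
assert torch.equal(F.max_pool2d(x, kernel_size=2),
                   F.max_pool2d(x, kernel_size=2, stride=2))
assert torch.equal(F.avg_pool2d(x, kernel_size=2),
                   F.avg_pool2d(x, kernel_size=2, stride=2))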
9 changes: 8 additions & 1 deletion src/frontends/pytorch/src/op/avg_poolnd.cpp
@@ -22,7 +22,14 @@ OutputVector translate_avg_poolnd(const NodeContext& context) {
num_inputs_check(context, 6, 7);
auto input = context.get_input(0);
auto kernel = context.const_input<Shape>(1);
auto strides = context.const_input<Strides>(2);
Strides strides;
if (!context.input_is_none(2)) {
strides = context.const_input<Strides>(2);
}
if (context.input_is_none(2) || strides.size() == 0) {
// In case strides are not provided default is kernel
strides = kernel;
}
auto pads = context.const_input<Shape>(3); // pytorch supports only symmetric padding
auto rounding_type = context.const_input<bool>(4) ? ov::op::RoundingType::CEIL : ov::op::RoundingType::FLOOR;
auto count_include_pad = context.const_input<bool>(5);
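For context (an illustration, not part of the commit): depending on how the model was exported, the stride slot of aten::avg_pool2d can arrive at the frontend as None or as an empty int list (the schema default used by scripting); the translator above now accepts both and falls back to the kernel size. A quick way to inspect what the frontend will see:

import torch

class AvgPool(torch.nn.Module):
    def forward(self, x):
        # stride omitted on purpose; PyTorch defaults it to kernel_size
        return torch.nn.functional.avg_pool2d(x, kernel_size=[3, 3])

# The printed graph shows the stride input of aten::avg_pool2d
# (expected to be an empty int list constant on recent PyTorch versions).
print(torch.jit.script(AvgPool()).graph)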
9 changes: 8 additions & 1 deletion src/frontends/pytorch/src/op/max_poolnd.cpp
@@ -16,7 +16,14 @@ using namespace ov::op;
OutputVector translate_max_poolnd(const NodeContext& context) {
num_inputs_check(context, 6, 6);
auto kernel = context.const_input<Shape>(1);
auto strides = context.const_input<Strides>(2);
Strides strides;
if (!context.input_is_none(2)) {
strides = context.const_input<Strides>(2);
}
if (context.input_is_none(2) || strides.size() == 0) {
// In case strides are not provided default is kernel
strides = kernel;
}
auto pads = context.const_input<Shape>(3); // pytorch supports only symmetric paddings
auto dilations = context.const_input<Strides>(4);
auto rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL : RoundingType::FLOOR;
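With both translators updated, a model that relies on the default stride should convert end to end. A sketch of such a check, assuming a build that contains this change and the openvino.convert_model PyTorch entry point available in newer releases:

import torch
import openvino as ov

class Pools(torch.nn.Module):
    def forward(self, x):
        # Neither call passes a stride, so both rely on the default-to-kernel behavior.
        a = torch.nn.functional.max_pool2d(x, kernel_size=[3, 3])
        return torch.nn.functional.avg_pool2d(a, kernel_size=[2, 2])

example = torch.randn(1, 3, 32, 32)
ov_model = ov.convert_model(torch.jit.script(Pools()), example_input=example)
compiled = ov.compile_model(ov_model, "CPU")
print(compiled([example.numpy()])[0].shape)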
57 changes: 33 additions & 24 deletions tests/layer_tests/pytorch_tests/test_pooling.py
@@ -5,23 +5,29 @@

from pytorch_layer_test_class import PytorchLayerTest

d2_avg_params = [{'kernel_size': [3, 3], 'stride': 1, 'padding': 0},
{'kernel_size': [3, 3], 'stride': [1, 1], 'padding': 1},
{'kernel_size': [3, 3], 'stride': [1, 1], 'padding': [0, 1]},
{'kernel_size': [3, 3], 'stride': [1, 1], 'padding': [1, 0]},
{'kernel_size': [3, 3], 'stride': [2, 1], 'padding': 0},
{'kernel_size': [2, 1], 'stride': [2, 1], 'padding': 0},
]

d1_avg_params = [{'kernel_size': 3, 'stride': 1, 'padding': 0},
{'kernel_size': (4,), 'stride': 1, 'padding': 1},
{'kernel_size': 4, 'stride': (5,), 'padding': 2},
]
d3_avg_params = [{'kernel_size': [3, 3, 3], 'stride': 1, 'padding': 0},
{'kernel_size': [3, 3, 3], 'stride': [1, 1, 1], 'padding': 1},
{'kernel_size': [3, 3, 3], 'stride': [3, 3, 3], 'padding': [0, 0, 0]},
{'kernel_size': [3, 2, 1], 'stride': [3, 1, 1], 'padding': [0, 0, 0]},
]
d2_params = [{'kernel_size': [3, 3], 'stride': 1, 'padding': 0},
{'kernel_size': [3, 3], 'stride': [1, 1], 'padding': 1},
{'kernel_size': [3, 3], 'stride': [1, 1], 'padding': [0, 1]},
{'kernel_size': [3, 3], 'stride': [1, 1], 'padding': [1, 0]},
{'kernel_size': [3, 3], 'stride': [2, 1], 'padding': 0},
{'kernel_size': [2, 1], 'stride': [2, 1], 'padding': 0},
{'kernel_size': [2, 1], 'stride': None, 'padding': 0},
{'kernel_size': [2, 1], 'stride': [], 'padding': 0},
]

d1_params = [{'kernel_size': 3, 'stride': 1, 'padding': 0},
{'kernel_size': (4,), 'stride': 1, 'padding': 1},
{'kernel_size': 4, 'stride': (5,), 'padding': 2},
{'kernel_size': 4, 'stride': None, 'padding': 0},
]
d3_params = [{'kernel_size': [3, 3, 3], 'stride': 1, 'padding': 0},
{'kernel_size': [3, 3, 3], 'stride': [1, 1, 1], 'padding': 1},
{'kernel_size': [3, 3, 3], 'stride': [3, 3, 3], 'padding': [0, 0, 0]},
{'kernel_size': [3, 2, 1], 'stride': [3, 1, 1], 'padding': [0, 0, 0]},
{'kernel_size': [3, 2, 1], 'stride': None, 'padding': [0, 0, 0]},
]


class TestPooling(PytorchLayerTest):
@@ -101,7 +107,7 @@ def forward(self, x):

return aten_pooling(), ref_net, f"aten::{op_type}"

@pytest.mark.parametrize("params", d1_avg_params)
@pytest.mark.parametrize("params", d1_params)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("count_include_pad", [True, False])
@pytest.mark.nightly
@@ -111,7 +117,7 @@ def test_avg_pool1d(self, params, ceil_mode, count_include_pad, ie_device, preci
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, trace_model=True,
dynamic_shapes=False)

@pytest.mark.parametrize("params", d2_avg_params)
@pytest.mark.parametrize("params", d2_params)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("count_include_pad", [True, False])
@pytest.mark.nightly
@@ -120,7 +126,7 @@ def test_avg_pool2d(self, params, ceil_mode, count_include_pad, ie_device, preci
self._test(*self.create_model("avg_pool2d", **params, ceil_mode=ceil_mode, count_include_pad=count_include_pad),
ie_device, precision, ir_version, trace_model=True, dynamic_shapes=False)

@pytest.mark.parametrize("params", d3_avg_params)
@pytest.mark.parametrize("params", d3_params)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("count_include_pad", [True, False])
@pytest.mark.nightly
@@ -130,7 +136,7 @@ def test_avg_pool3d(self, params, ceil_mode, count_include_pad, ie_device, preci
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 5}, trace_model=True,
dynamic_shapes=False)

@pytest.mark.parametrize("params", d1_avg_params)
@pytest.mark.parametrize("params", d1_params)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.nightly
@@ -139,16 +145,19 @@ def test_max_pool1d(self, params, ceil_mode, dilation, ie_device, precision, ir_
self._test(*self.create_model("max_pool1d", **params, ceil_mode=ceil_mode, dilation=dilation),
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, dynamic_shapes=False)

@pytest.mark.parametrize("params", d2_avg_params)
@pytest.mark.parametrize("params", d2_params)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.nightly
@pytest.mark.precommit
def test_max_pool2d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
to_trace = False
if params["stride"] == []:
to_trace = True
self._test(*self.create_model("max_pool2d", **params, ceil_mode=ceil_mode, dilation=dilation),
ie_device, precision, ir_version, dynamic_shapes=False)
ie_device, precision, ir_version, dynamic_shapes=False, trace_model=to_trace)

@pytest.mark.parametrize("params", d3_avg_params)
@pytest.mark.parametrize("params", d3_params)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.nightly
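The new stride=None and stride=[] parametrizations exercise the default-stride path from the Python API; the empty-stride max_pool2d case is exported via tracing (trace_model=True). A hypothetical local run of the updated tests (assumes an OpenVINO developer environment with the PyTorch frontend built and the layer-test requirements installed):

import pytest

# Run only the 2D pooling cases from the updated test module.
pytest.main(["tests/layer_tests/pytorch_tests/test_pooling.py", "-k", "pool2d"])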
