diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 28703dae3661..cc7cd4830cd4 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1342,10 +1342,31 @@ def _impl(inputs, input_types): def _pad(): def _impl(inputs, input_types): data = inputs[0] - padding = inputs[1] - pad_width = list(zip(padding, padding)) + if isinstance(inputs[1], list): + pad_list = inputs[1] + else: + pad_list = list(_infer_shape(inputs[1])) + + # initialize paddings based on input len + pad_len = len(_infer_shape(data)) * 2 + paddings = [0] * pad_len + + if len(pad_list) >= 2: + paddings[-1] = pad_list[1] + paddings[-2] = pad_list[0] + if len(pad_list) >= 4: + paddings[-3] = pad_list[3] + paddings[-4] = pad_list[2] + if len(pad_list) >= 6: + paddings[-5] = pad_list[5] + paddings[-6] = pad_list[4] + + # group into tuple of 2 ints + paddings = [paddings[i:i + 2] for i in range(0, len(paddings), 2)] + pad_value = inputs[2] - return _op.nn.pad(data, pad_width, pad_value) + + return _op.nn.pad(data, paddings, pad_value) return _impl diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index f1543f02ebd9..85928bfd60c2 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1020,6 +1020,50 @@ def test_adaptive_pool3d(): verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp) +def test_forward_functional_pad(): + torch.set_grad_enabled(False) + pad = (0, 0) + class Pad1(Module): + def forward(self, *args): + return torch.nn.functional.pad(args[0], pad, "constant", 0) + + input_data = torch.rand((3, 3, 4, 2)) + pad = (1, 1) + verify_model(Pad1().float().eval(), input_data=input_data) + + pad = (1, 1, 2, 2) + verify_model(Pad1().float().eval(), input_data=input_data) + + pad = (0, 1, 2, 1, 3, 3) + verify_model(Pad1().float().eval(), input_data=input_data) + + +def test_forward_zero_pad2d(): + inp = torch.rand((1, 1, 3, 3)) + verify_model(torch.nn.ZeroPad2d(2).eval(), inp) + verify_model(torch.nn.ZeroPad2d((1, 1, 2, 0)).eval(), inp) + + +def test_forward_constant_pad1d(): + inp = torch.rand((1, 2, 4)) + verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp) + + inp = torch.rand((1, 2, 3)) + verify_model(torch.nn.ConstantPad2d((3, 1), 3.5).eval(), inp) + + +def test_forward_constant_pad2d(): + inp = torch.rand((1, 2, 2, 2)) + verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp) + verify_model(torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5).eval(), inp) + + +def test_forward_constant_pad3d(): + inp = torch.rand((1, 3, 2, 2, 2)) + verify_model(torch.nn.ConstantPad3d(3, 3.5).eval(), inp) + verify_model(torch.nn.ConstantPad3d((3, 4, 5, 6, 0, 1), 3.5).eval(), inp) + + def test_forward_reflection_pad2d(): inp = torch.rand((1, 1, 3, 3)) verify_model(torch.nn.ReflectionPad2d(2).eval(), inp) @@ -2200,6 +2244,11 @@ def forward(self, *args): test_upsample() test_forward_upsample3d() test_to() + test_forward_functional_pad() + test_forward_zero_pad2d() + test_forward_constant_pad1d() + test_forward_constant_pad2d() + test_forward_constant_pad3d() test_forward_reflection_pad2d() test_adaptive_pool3d() test_conv3d()