Skip to content

Commit

Permalink
[PYTORCH]Padding support (apache#5638)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored and trevor-m committed Jun 18, 2020
1 parent 4d3a54f commit 6120708
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 3 deletions.
27 changes: 24 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,10 +1342,31 @@ def _impl(inputs, input_types):
def _pad():
def _impl(inputs, input_types):
data = inputs[0]
padding = inputs[1]
pad_width = list(zip(padding, padding))
if isinstance(inputs[1], list):
pad_list = inputs[1]
else:
pad_list = list(_infer_shape(inputs[1]))

# initialize paddings based on input len
pad_len = len(_infer_shape(data)) * 2
paddings = [0] * pad_len

if len(pad_list) >= 2:
paddings[-1] = pad_list[1]
paddings[-2] = pad_list[0]
if len(pad_list) >= 4:
paddings[-3] = pad_list[3]
paddings[-4] = pad_list[2]
if len(pad_list) >= 6:
paddings[-5] = pad_list[5]
paddings[-6] = pad_list[4]

# group into tuple of 2 ints
paddings = [paddings[i:i + 2] for i in range(0, len(paddings), 2)]

pad_value = inputs[2]
return _op.nn.pad(data, pad_width, pad_value)

return _op.nn.pad(data, paddings, pad_value)
return _impl


Expand Down
49 changes: 49 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,50 @@ def test_adaptive_pool3d():
verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp)


def test_forward_functional_pad():
torch.set_grad_enabled(False)
pad = (0, 0)
class Pad1(Module):
def forward(self, *args):
return torch.nn.functional.pad(args[0], pad, "constant", 0)

input_data = torch.rand((3, 3, 4, 2))
pad = (1, 1)
verify_model(Pad1().float().eval(), input_data=input_data)

pad = (1, 1, 2, 2)
verify_model(Pad1().float().eval(), input_data=input_data)

pad = (0, 1, 2, 1, 3, 3)
verify_model(Pad1().float().eval(), input_data=input_data)


def test_forward_zero_pad2d():
inp = torch.rand((1, 1, 3, 3))
verify_model(torch.nn.ZeroPad2d(2).eval(), inp)
verify_model(torch.nn.ZeroPad2d((1, 1, 2, 0)).eval(), inp)


def test_forward_constant_pad1d():
inp = torch.rand((1, 2, 4))
verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp)

inp = torch.rand((1, 2, 3))
verify_model(torch.nn.ConstantPad2d((3, 1), 3.5).eval(), inp)


def test_forward_constant_pad2d():
inp = torch.rand((1, 2, 2, 2))
verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp)
verify_model(torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5).eval(), inp)


def test_forward_constant_pad3d():
inp = torch.rand((1, 3, 2, 2, 2))
verify_model(torch.nn.ConstantPad3d(3, 3.5).eval(), inp)
verify_model(torch.nn.ConstantPad3d((3, 4, 5, 6, 0, 1), 3.5).eval(), inp)


def test_forward_reflection_pad2d():
inp = torch.rand((1, 1, 3, 3))
verify_model(torch.nn.ReflectionPad2d(2).eval(), inp)
Expand Down Expand Up @@ -2200,6 +2244,11 @@ def forward(self, *args):
test_upsample()
test_forward_upsample3d()
test_to()
test_forward_functional_pad()
test_forward_zero_pad2d()
test_forward_constant_pad1d()
test_forward_constant_pad2d()
test_forward_constant_pad3d()
test_forward_reflection_pad2d()
test_adaptive_pool3d()
test_conv3d()
Expand Down

0 comments on commit 6120708

Please sign in to comment.