Skip to content

Commit

Permalink
[PYTORCH]ReflectionPad2d op (#5624)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored May 20, 2020
1 parent 78e5aa1 commit 0d1a954
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
18 changes: 18 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,22 @@ def _impl(inputs, input_types):
return _impl


def _reflection_pad2d():
def _impl(inputs, input_types):
if isinstance(inputs[1], list):
pad_list = inputs[1]
else:
pad_list = list(_infer_shape(inputs[1]))
padding_left = pad_list[0]
padding_right = pad_list[1]
padding_top = pad_list[2]
padding_bottom = pad_list[3]
paddings = [[0, 0], [0, 0], [padding_top, padding_bottom], [padding_left, padding_right]]

return _op.nn.mirror_pad(inputs[0], paddings, mode='REFLECT')
return _impl


# Helper functions for operator implementation
def _convert_dtype_value(val):
convert_torch_dtype_map = {7:"torch.float64",
Expand Down Expand Up @@ -1695,6 +1711,7 @@ def _get_convert_map(prelude):
"aten::prelu" : _prelu(),
"aten::leaky_relu" : _leaky_relu(),
"aten::elu" : _elu(),
"aten::elu_" : _elu(),
"aten::celu" : _celu(),
"aten::gelu" : _gelu(),
"aten::selu" : _selu(),
Expand Down Expand Up @@ -1798,6 +1815,7 @@ def _get_convert_map(prelude):
"aten::embedding" : _embedding(),
"aten::one_hot" : _one_hot(),
"aten::mm" : _matmul(prelude),
"aten::reflection_pad2d" : _reflection_pad2d(),
"relay::tensor_array_stack" : _tensor_array_stack(prelude),
"aten::add" : _add(prelude),
"aten::add_" : _add(prelude),
Expand Down
10 changes: 10 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,15 @@ def test_adaptive_pool3d():
verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp)


def test_forward_reflection_pad2d():
inp = torch.rand((1, 1, 3, 3))
verify_model(torch.nn.ReflectionPad2d(2).eval(), inp)
verify_model(torch.nn.ReflectionPad2d((1, 1, 2, 0)).eval(), inp)

inp = torch.rand((2, 4, 5, 6))
verify_model(torch.nn.ReflectionPad2d((1, 3, 2, 4)).eval(), inp)


def test_conv3d():
for ishape in [(1, 32, 16, 16, 16),
(1, 32, 9, 15, 15),
Expand Down Expand Up @@ -2183,6 +2192,7 @@ def forward(self, *args):
test_forward_split()
test_upsample()
test_to()
test_forward_reflection_pad2d()
test_adaptive_pool3d()
test_conv3d()

Expand Down

0 comments on commit 0d1a954

Please sign in to comment.