Skip to content

Commit

Permalink
[TFLite] Fix padding calculation in Transpose Conv (apache#9089)
Browse files Browse the repository at this point in the history
* [TFLite] Fix padding caculation in Transpose Conv

* [TFLite] Fix padding calculation in Transpose Conv

* [TFLite] Fix padding calculation in Transpose Conv

* remove unused variables
  • Loading branch information
euntaik authored and ylc committed Jan 13, 2022
1 parent 6029498 commit 970e0fb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2858,7 +2858,7 @@ def convert_transpose_conv(self, op):

# Input (data) Tensor. NHWC layout
input_tensor = input_tensors[2]
_, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor))
_, _, _, input_c = to_int_list(self.get_tensor_shape(input_tensor))
# Weights tensor. TFLite uses OHWI layout
weights_tensor = input_tensors[1]
out_channels, kernel_h, kernel_w, in_channels = to_int_list(
Expand Down Expand Up @@ -2919,8 +2919,9 @@ def convert_transpose_conv(self, op):
), "Output channel in the filter should match to channel in the output_shape"

if padding == Padding.SAME:
pad_top, pad_bottom = get_pad_value(input_h, kernel_h, stride_h)
pad_left, pad_right = get_pad_value(input_w, kernel_w, stride_w)
output_h, output_w = output_shape_value[1], output_shape_value[2]
pad_top, pad_bottom = get_pad_value(output_h, kernel_h, stride_h)
pad_left, pad_right = get_pad_value(output_w, kernel_w, stride_w)
padding = (pad_top, pad_left, pad_bottom, pad_right)
else:
padding = (0, 0, 0, 0)
Expand Down
20 changes: 20 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,26 @@ def _test_transpose_conv(
def test_forward_transpose_conv():
for quantized in [True, False]:
for fp16_quantized in [True, False]:
# odd size input, padding VALID
_test_transpose_conv(
[1, 5, 6, 16],
[2, 2, 16, 16],
[1, 10, 12, 16],
[2, 2],
"VALID",
quantized,
fp16_quantized,
)
# odd size input, padding SAME
_test_transpose_conv(
[1, 5, 6, 16],
[2, 2, 16, 16],
[1, 10, 12, 16],
[2, 2],
"SAME",
quantized,
fp16_quantized,
)
# kernel 3x3, padding VALID
_test_transpose_conv(
[4, 32, 32, 16],
Expand Down

0 comments on commit 970e0fb

Please sign in to comment.