Skip to content

Commit

Permalink
add transformation from NHWC to NCHW to compatible with TVM conv2d_tr…
Browse files Browse the repository at this point in the history
…anspose implementation
  • Loading branch information
optima2005 committed Nov 13, 2019
1 parent 52458a7 commit 5acdaee
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
12 changes: 10 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,16 @@ def _impl(inputs, attr, params):
flip_layout = False

if opname == 'conv_transpose' and attr['data_format'] == 'NHWC':
raise NotImplementedError( \
"conv2d_transpose with NHWC layout is not implemented.")
# transform to NCHW for TVM backend compatible and set 'flip_layout'
# to have output flip back to NHWC
tmp_shape = attr['_input_shapes'][inputs[2]]
tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2))
attr['_input_shapes'][inputs[2]] = tmp_shape
attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
attr['strides'][3], attr['strides'][1], attr['strides'][2]
attr['data_format'] = 'NCHW'
flip_layout = True

inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]

Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,15 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 19])
_test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 17, 17, 124])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 12])


#######################################################################
# BiasAdd
Expand Down

0 comments on commit 5acdaee

Please sign in to comment.