From 08e3fef1751ba6e1713ce3c15621acccf4a8928a Mon Sep 17 00:00:00 2001 From: optima2005 Date: Mon, 11 Nov 2019 03:20:27 +0000 Subject: [PATCH] [Relay][Frontend][Tensorflow]Add conv2d_transpose --- python/tvm/relay/frontend/tensorflow.py | 52 ++++++++++++------- .../frontend/tensorflow/test_forward.py | 22 +++++++- 2 files changed, 54 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 0abcb09d6ace5..cd9c8eebdaad8 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -188,10 +188,16 @@ def _impl(inputs, attr, params): attr['data_format'] = attr['data_format'].decode("utf-8") flip_layout = False + if opname == 'conv_transpose' and attr['data_format'] == 'NHWC': + raise NotImplementedError( \ + "conv2d_transpose with NHWC layout is not implemented.") + + inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] + # NCHW Layout require weights transpose if attr['data_format'] == 'NCHW': tmp_shape = attr['_input_shapes'][inputs[1]] - if opname == 'conv': + if opname in ['conv', 'conv_transpose']: tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) else: @@ -199,13 +205,13 @@ def _impl(inputs, attr, params): inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) attr['_input_shapes'][inputs[1]] = tmp_shape - input_shape = attr['_input_shapes'][inputs[0]] + input_shape = attr['_input_shapes'][inputs_data] weights_shape = attr['_input_shapes'][inputs[1]] if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] - inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) - if opname == 'conv': + inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2)) + if opname in ['conv', 'conv_transpose']: weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) else: @@ -221,6 +227,8 @@ def _impl(inputs, attr, params): attr['kernel_shape'] = (weights_shape[0], weights_shape[1]) if opname == 'conv': attr['channels'] = weights_shape[3] + elif opname == 'conv_transpose': + attr['channels'] = weights_shape[2] else: attr['channels'] = input_shape[3] * depth_mult @@ -232,6 +240,8 @@ def _impl(inputs, attr, params): attr['kernel_shape'] = (weights_shape[2], weights_shape[3]) if opname == 'conv': attr['channels'] = weights_shape[0] + elif opname == 'conv_transpose': + attr['channels'] = weights_shape[1] else: attr['channels'] = input_shape[1] * depth_mult if attr['channels'] < 0: @@ -272,17 +282,17 @@ def _impl(inputs, attr, params): if attr['data_format'] == 'NHWC': - inputs[0] = _op.nn.pad(data=inputs[0], - pad_width=((0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]), - (0, 0))) + inputs_data = _op.nn.pad(data=inputs_data, + pad_width=((0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1]), + (0, 0))) else: - inputs[0] = _op.nn.pad(data=inputs[0], - pad_width=((0, 0), - (0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]))) + inputs_data = _op.nn.pad(data=inputs_data, + pad_width=((0, 0), + (0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1]))) attr['padding'] = [0, 0] @@ -292,25 +302,28 @@ def _impl(inputs, attr, params): raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) if 'kernel_layout' not in attr: - if opname == 'conv': + if opname in ['conv', 'conv_transpose']: attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW' else: attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW' - use_bias = len(inputs) == 3 + use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4) channel_axis = 1 if attr['data_format'] == "NCHW" else 3 out = AttrCvt( - op_name=_dimension_picker('conv'), + op_name=_dimension_picker('conv', \ + surfix="_transpose" if opname == 'conv_transpose' else ""), transforms={ 'kernel_shape': 'kernel_size', 'data_format': 'data_layout', 'dilations': ('dilation', (0, 0)), 'group': ('groups', 1)}, - custom_check=_dimension_constraint())([inputs[0], inputs[1]], attr) + custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr) if use_bias: - out = _op.nn.bias_add(out, inputs[2], axis=channel_axis) + out = _op.nn.bias_add(out, + inputs[2] if opname != 'conv_transpose' else inputs[3], + axis=channel_axis) if flip_layout: out = _op.transpose(out, axes=(0, 2, 3, 1)) @@ -1385,6 +1398,7 @@ def _impl(inputs, attr, params): 'Concat' : _concat(), 'ConcatV2' : _concatV2(), 'Conv2D' : _conv('conv'), + 'Conv2DBackpropInput' : _conv('conv_transpose'), 'CropAndResize' : _crop_and_resize(), 'DecodeJpeg' : _decode_image(), 'DepthwiseConv2dNative' : _conv('depthwise'), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c397d05f62ef0..136c42f499fe6 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -295,7 +295,8 @@ def test_forward_pooling(): def _test_convolution(opname, tensor_in_sizes, filter_in_sizes, - dilations, strides, padding, data_format): + dilations, strides, padding, data_format, + deconv_output_shape=[]): """ One iteration of convolution with given shapes and attributes """ total_size_1 = np.prod(tensor_in_sizes) @@ -326,6 +327,17 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes, compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'), 'Placeholder:0', 'Conv2D:0') + elif opname == 'conv_transpose': + nn_ops.conv2d_transpose(in_data, + in_filter, + output_shape=deconv_output_shape, + strides=strides, + dilations=dilations, + padding=padding, + data_format=data_format) + + compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'), + 'Placeholder:0', 'conv2d_transpose:0') else: nn_ops.depthwise_conv2d_native(in_data, in_filter, @@ -349,6 +361,14 @@ def test_forward_convolution(): _test_convolution('depthwise', [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NCHW') _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NCHW') _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW') + _test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', + 'NCHW', [4, 176, 8, 8]) + _test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', + 'NCHW', [4, 19, 17, 17]) + _test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', + 'NCHW', [4, 124, 17, 17]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', + 'NCHW', [4, 12, 17, 17]) _test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC') _test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')