From ec7acb923d5257b020a1fe976b3ad7a91ae5fa55 Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 19 Mar 2019 12:25:07 +0530 Subject: [PATCH] [FRONTEND][TENSORFLOW] Enhance with left over patches from NNVM. (#2757) * [FRONTEND][TENSORFLOW] Enhance with left over patches from NNVM. commit 76188a4 Author: Siva sivar.b@huawei.com [NNVM][TENSORFLOW] bugfix. (#2444) commit 6737739 Author: Ashutosh Parkhi ashutosh.parkhi@imgtec.com [Tensorflow] Support for Crop (#2285) commit f6c3f99 Author: Alexey Romanov alexey.v.romanov@gmail.com [FRONTEND][TENSORFLOW] Use input shapes directly instead of 1-element lists (#2242) commit e5d92e1 Author: Dominic Symes 36929632+dominicsymes@users.noreply.github.com [FRONTEND][TENSORFLOW] Bugfix (#2326) commit 00d509d Author: Alexey Romanov alexey.v.romanov@gmail.com [FRONTEND][TENSORFLOW] Support Unstack and Split (#2105) commit df9d3ad Author: Siva sivar.b@huawei.com [FRONTEND][TENSORFLOW] Bugfix (#2267) commit d1a0c90 Author: Zhebin Jin zhebin.jzb@alibaba-inc.com [FRONTEND][TENSORFLOW]Add Split and realdiv op support (#2123) * Add Split and realdiv op support * Fix the pad calculation in the case of dilated convolution * * review comments * * resnet fix. * * review comments --- .../frontend/tensorflow/test_forward.py | 23 +- python/tvm/relay/frontend/tensorflow.py | 213 +++++++++++++++--- .../frontend/tensorflow/test_forward.py | 121 +++++++++- 3 files changed, 309 insertions(+), 48 deletions(-) diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 8a182e7d2334..0a7cfac91bfa 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -137,7 +137,7 @@ def is_gpu_available(): from tensorflow.python.client import device_lib local_device_protos = device_lib.list_local_devices() gpu_list = [x.name for x in local_device_protos if x.device_type == 'GPU'] - if len(gpu_list) < 0: + if len(gpu_list) > 0: print("Tensorflow GPU:", gpu_list) return True else: @@ -168,7 +168,7 @@ def _test_pooling(input_shape, **kwargs): if is_gpu_available(): input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] - kwargs['data_layout'] = 'NCHW' + kwargs['data_format'] = 'NCHW' _test_pooling_iteration(input_shape, **kwargs) def test_forward_pooling(): @@ -225,8 +225,12 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32') in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32') - strides = [1] + strides + [1] - dilations = [1] + dilations + [1] + if data_format == 'NHWC': + strides = [1] + strides + [1] + dilations = [1] + dilations + [1] + else: + strides = [1, 1] + strides + dilations = [1, 1] + dilations nn_ops.conv2d(in_data, in_filter, @@ -899,7 +903,7 @@ def test_forward_mobilenet(): ####################################################################### # ResnetV2 -# --------- +# -------- def test_forward_resnetv2(): '''test resnet model''' if is_gpu_available(): @@ -913,8 +917,13 @@ def test_forward_resnetv2(): with tf.Session() as sess: tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0') - tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32') - tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) + for device in ["llvm", "cuda"]: + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is 
not enabled" % device) + continue + tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device) + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) ####################################################################### # PTB diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 8d53b003da1e..0efebe3cfec9 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -81,6 +81,7 @@ def __call__(self, inputs, attrs, *args): self._ignores.append('_node_name') self._ignores.append('is_training') self._ignores.append('_target_layout') + self._ignores.append('_input_0d_mismatch') # apply custom check if self._custom_check: @@ -227,7 +228,7 @@ def _impl(inputs, attr, params): attr['data_format'] = attr['data_format'].decode("utf-8") flip_layout = False - input_shape = attr['_input_shapes'][inputs[0]][0] + input_shape = attr['_input_shapes'][inputs[0]] if attr['data_format'] == 'NHWC': attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2]) @@ -239,7 +240,7 @@ def _impl(inputs, attr, params): raise TypeError("Unsupported data_format type : {}".format(attr['data_format'])) if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": - tmp_shape = attr['_input_shapes'][inputs[0]][0] + tmp_shape = attr['_input_shapes'][inputs[0]] input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) attr['data_format'] = "NCHW" @@ -292,13 +293,13 @@ def _impl(inputs, attr, params): # NCHW Layout require weights transpose if attr['data_format'] == 'NCHW': - tmp_shape = attr['_input_shapes'][inputs[1]][0] + tmp_shape = attr['_input_shapes'][inputs[1]] tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) - attr['_input_shapes'][inputs[1]] = [tmp_shape] + attr['_input_shapes'][inputs[1]] = tmp_shape - input_shape = attr['_input_shapes'][inputs[0]][0] - weights_shape = attr['_input_shapes'][inputs[1]][0] + input_shape = attr['_input_shapes'][inputs[0]] + weights_shape = attr['_input_shapes'][inputs[1]] if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] @@ -323,7 +324,7 @@ def _impl(inputs, attr, params): attr['channels'] = input_shape[3] * depth_mult if 'dilations' in attr: - attr['dilations'] = (attr['dilations'][0], attr['dilations'][1]) + attr['dilations'] = (attr['dilations'][1], attr['dilations'][2]) attr['strides'] = (attr['strides'][1], attr['strides'][2]) elif attr['data_format'] == 'NCHW': depth_mult, _, kernel_h, kernel_w = weights_shape @@ -360,8 +361,13 @@ def _impl(inputs, attr, params): in_h = input_shape[2] in_w = input_shape[3] - pad_v = _get_pad_pair(in_h, kernel_h, stride_h) - pad_h = _get_pad_pair(in_w, kernel_w, stride_w) + dilation_h = attr['dilations'][0] + dilation_w = attr['dilations'][1] + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) + pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) + if attr['data_format'] == 'NHWC': inputs[0] = _op.nn.pad(data=inputs[0], @@ -425,8 +431,7 @@ def _impl(inputs, attr, params): dim_input = inputs.pop(1) axis = params[dim_input.name_hint] params.pop(dim_input.name_hint) - return AttrCvt(op_name="expand_dims", ignores=['Tdim'], - extras={'axis': int(axis.asnumpy()[0])})(inputs, attr) 
+ return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0]) return _impl def _resize_bilinear(): @@ -461,6 +466,11 @@ def _impl(inputs, attr, params): return _impl +def _undef(): + def _impl(inputs, attr, params): + return _sym.__undef__() + return _impl + def _identity(): def _impl(inputs, attr, params): return inputs[0] @@ -489,10 +499,26 @@ def _impl(inputs, attr, params): def _pack(): def _impl(inputs, attr, params): axis = int(attr["axis"]) - inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs] + inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs] return _op.concatenate(inputs_reshaped, axis) return _impl +def _slice(): + def _impl(inputs, attr, params): + begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist() + size = params.pop(_get_name_hint(inputs[2])).asnumpy().tolist() + data_shape = attr['_input_shapes'][inputs[0]] + data_dim = len(data_shape) + end = size + for i in range(data_dim): + if size[i] == -1: + end[i] = data_shape[i] - begin[i] + else: + end[i] += begin[i] + return _op.strided_slice(inputs[0], begin=begin, end=size) + return _impl + + def _reshape(): def _impl(inputs, attr, params): try: @@ -596,7 +622,7 @@ def _impl(inputs, attr, params): def _shape(): def _impl(inputs, attr, params): - return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32') + return np.array(attr['_input_shapes'][inputs[0]], dtype='int32') return _impl def _fill(): @@ -671,7 +697,7 @@ def _impl(inputs, attr, params): new_axis_mask = int(attr.get('new_axis_mask', 0)) shrink_axis_mask = int(attr.get('shrink_axis_mask', 0)) data_shape = attr['_input_shapes'][inputs[0]] - data_dim = len(data_shape[0]) + data_dim = len(data_shape) stride_dim = len(stride) def _transform_mask(stride_dim, ellipsis_mask): @@ -702,7 +728,7 @@ def _transform_mask(stride_dim, ellipsis_mask): + new_axes_after_ellipsis), data_dim) for i in range(final_index, to_index): m_begin[final_index] = 0 - m_end[final_index] = data_shape[0][final_index] + m_end[final_index] = data_shape[final_index] m_stride[final_index] = 1 fshape_indices.append(final_index) final_index += 1 @@ -712,19 +738,19 @@ def _transform_mask(stride_dim, ellipsis_mask): if final_index == len(m_begin): break if mask & begin_mask: - m_begin[final_index] = data_shape[0][final_index] \ + m_begin[final_index] = data_shape[final_index] \ if stride[index] < 0 else 0 elif begin[index]: m_begin[final_index] = begin[index] if mask & end_mask: m_end[final_index] = 0 if stride[index] < 0 \ - else data_shape[0][final_index] + else data_shape[final_index] elif end[index]: m_end[final_index] = end[index] m_stride[final_index] = stride[index] if mask & shrink_axis_mask: #Tensorflow make axis with shrink_axis_mask as dimension 1 - m_begin[final_index] = data_shape[0][final_index] + begin[index] \ + m_begin[final_index] = data_shape[final_index] + begin[index] \ if begin[index] < 0 else begin[index] m_end[final_index] = begin[index] + 1 m_stride[final_index] = 1 @@ -752,6 +778,9 @@ def _transform_mask(stride_dim, ellipsis_mask): pass else: final_output.append(out_shape[gather_index]) + # Prevent 0-dim tensors which are not accepted by Relay + if not final_output: + final_output.append(1) return _op.reshape(out, newshape=tuple(final_output)) return _impl @@ -789,11 +818,10 @@ def _impl(inputs, attr, params): def _rank(): def _impl(inputs, attr, params): - input_shapes = attr['_input_shapes'][inputs[0]] - assert len(inputs) == 1 + input_shape = attr['_input_shapes'][inputs[0]] name 
= attr["_node_name"] - params[name] = tvm.nd.array([len(input_shapes[0])]) + params[name] = tvm.nd.array([len(input_shape)]) return [_expr.var(name, shape=params[name].shape, dtype='int32')] @@ -844,6 +872,72 @@ def _impl(inputs, attr, params): )(inputs, attr) return _impl +def _split(has_size_vector): + # TF documentation https://www.tensorflow.org/api_docs/python/tf/split + def _impl(inputs, attr, params): + try: + # order and number of inputs are different: + # if has_size_vector: + # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v + # else: + # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split + + # in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow, + # we can only support constants + if has_size_vector: + input_node_index = 0 + input_axis_index = 2 + size_splits_input_name = _get_name_hint(inputs[1]) + size_splits = params[size_splits_input_name].asnumpy() + section_beginnings = np.cumsum(size_splits)[:-1] + indices_or_sections = tuple(section_beginnings) + else: + input_node_index = 1 + input_axis_index = 0 + indices_or_sections = attr['num_split'] + input_node = inputs[input_node_index] + axis_input_name = _get_name_hint(inputs[input_axis_index]) + axis_input_value = params[axis_input_name].asnumpy()[0] + except (IndexError, KeyError): + raise TypeError( \ + "Unsupported argument for split: `axis` and `num_or_size_splits` " \ + "should be constants") + return _op.split(input_node, + indices_or_sections=indices_or_sections, + axis=int(axis_input_value)) + return _impl + +def _unpack(): + def _impl(inputs, attr, params): + input_node = inputs[0] + axis = attr['axis'] + input_shape = attr['_input_shapes'][input_node] + axis_length = input_shape[axis] + if axis_length < 0: + raise TypeError("Unstack with unknown axis length") + splitted = _op.split(input_node, + indices_or_sections=axis_length, + axis=axis) + #name=attr.get('_node_name', 'unstack')) + if axis == 0: + axis = None + else: + axis = [axis] + return _expr.TupleWrapper( + _expr.Tuple([_op.squeeze(split_item, axis=axis) \ + for split_item in splitted]), len(splitted)) + return _impl + +def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1): + if data in attr['_input_0d_mismatch']: + return data if num_newaxis == 1 else \ + AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'], + extras={'axis': int(axis), 'num_newaxis': int(num_newaxis-1)})([data], attr) + + return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'], + extras={'axis': int(axis), 'num_newaxis': int(num_newaxis)})([data], attr) + + def _softmax(): def _impl(inputs, attr, params): return AttrCvt(op_name='softmax', @@ -885,11 +979,13 @@ def _impl(inputs, attr, params): 'Add' : _elemwise('add'), 'Sub' : _elemwise('subtract'), 'Mul' : _elemwise('multiply'), + 'RealDiv' : _elemwise('div'), 'Maximum' : _elemwise('maximum'), 'Minimum' : _elemwise('minimum'), 'Sum' : _sum(), 'Square' : _square(), 'Pack' : _pack(), + 'Slice' : _slice(), 'LeakyRelu' : AttrCvt('leaky_relu'), 'Relu' : AttrCvt('relu'), 'Reshape' : _reshape(), @@ -924,6 +1020,9 @@ def _impl(inputs, attr, params): 'GreaterEqual' : _broadcast('greater_equal'), 'Equal' : _broadcast('equal'), 'NotEqual' : _broadcast('not_equal'), + 'Split' : _split(False), + 'SplitV' : _split(True), + 'Unpack' : _unpack(), } def _LSTMBlockCell(): @@ -958,8 +1057,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params): forget_bias = attr.pop('forget_bias') input_shape = attr['_input_shapes'][inputs[0]] weight_shape = attr['_input_shapes'][inputs[3]] - batch_size, 
input_size = input_shape[0][0], input_shape[0][1] - num_hidden_layers = weight_shape[0][1] + batch_size, input_size = input_shape[0], input_shape[1] + num_hidden_layers = weight_shape[1] num_hidden = num_hidden_layers // 4 in_data = _op.reshape(in_data, @@ -1087,8 +1186,8 @@ def _LSTMBlockCellWrapper(inputs, attr, params, input_shape = attr['_input_shapes'][inputs[0]] weight_shape = attr['_input_shapes'][inputs[3]] - batch_size = input_shape[0][0] - num_hidden = weight_shape[0][1] // 4 + batch_size = input_shape[0] + num_hidden = weight_shape[1] // 4 if layer == 0: #Create initial states placeholder in case of first layer @@ -1183,6 +1282,8 @@ def __init__(self): self._output_shapes = {} self._num_param = 0 self._num_rnn_layer = False + self._outputs_are_0d = {} + self._input_shapes = {} def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): """Construct relay nodes from tensorflow graph definition - GraphDef. @@ -1259,6 +1360,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # Operator name 'Const' is treated as a parameter to build params dict. input_shapes = {} + input_0d_mismatch = set() attr = self._parse_attr(node.attr) # Variable converted to Const will not have only value attr @@ -1267,6 +1369,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): elif shape and node.name in shape: # Give priority to user argument. self._output_shapes[node.name] = [shape[node.name]] + elif node.op == 'Placeholder': + self._output_shapes[node.name] = [self._input_shapes[node.name]] elif '_output_shapes' in attr: self._output_shapes[node.name] = \ [tensor_util.TensorShapeProtoToList(tshape) \ @@ -1274,8 +1378,13 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): else: # Keep the list indexable to avoid key error. # Actual value will be filled after node creation. + # Will infer shapes if the graph is not frozen with add_shapes=True self._output_shapes[node.name] = [None] + self._outputs_are_0d[node.name] = [ \ + not shape if isinstance(tshape, list) else False \ + for tshape in self._output_shapes[node.name]] + if node.op == "Placeholder": self._output_shapes[node.name] = [self._input_shapes[node.name]] self._nodes[node.name] = [_expr.var(node.name, @@ -1315,10 +1424,33 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # Fill shapes for all inputs in a list inputs = [] for i in node.input: - if i in self._nodes: - inputs.append(self._nodes[i][0]) - input_shapes[self._nodes[i][0]] = self._output_shapes[i] + # Some TensorFlow operators internally maintain execution layers + # and their output name includes the layer number along with + # graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but the + # output tensor name is 'Model/RNN/cell_0/RnnCell:0'. In this case, + # the number has to be ignored for single-output nodes. + # On the other hand, for multi-output nodes the number is the output index, + # and the lack of the number implies 0. + tensor_name = i.split(':') + node_name = tensor_name[0] + if node_name in self._nodes: + in_sym = self._nodes[node_name] + if isinstance(in_sym, _expr.TupleWrapper): + tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0 + in_sym = [in_sym[tensor_slot]] + input_shape = self._output_shapes[node_name][tensor_slot] + else: + tensor_slot = 0 + input_shape = self._output_shapes[node_name][0] + inputs.append(in_sym[0]) + input_shapes[in_sym[0]] = input_shape + # This means the node is 1d in Relay and 0d in TF. 
+ # See `_expand_dims_0d_aware`. + if self._outputs_are_0d[node_name][tensor_slot] and input_shape: + input_0d_mismatch.add(in_sym) + attr['_input_shapes'] = input_shapes + attr['_input_0d_mismatch'] = input_0d_mismatch op = self._convert_operator(node.op, inputs, attr, graph) @@ -1340,23 +1472,36 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # Infer shapes even without specifying "add_shapes=True" if output_shapes == [None]: - out_type = ir_pass.infer_type(self._nodes[node.name][0]) - self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)] + out_shapes = [] + for node_item in self._nodes[node.name]: + out_type = ir_pass.infer_type(node_item) + out_shapes.append(get_const_tuple(out_type.checked_type.shape)) + self._output_shapes[node.name] = out_shapes if self._output_shapes[node.name] and shape and node.name in shape: assert self._output_shapes[node.name] == list(shape[node.name]) # Infer shapes if passed explicitely node_output = self._nodes[node.name] - out_type = ir_pass.infer_type(node_output[0]) - self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)] - + if shape and (not self._output_shapes[node.name][0] + or -1 in self._output_shapes[node.name][0]): + out_shapes = [] + for node_item in node_output: + out_type = ir_pass.infer_type(node_item) + out_shapes.append(get_const_tuple(out_type.checked_type.shape)) + self._output_shapes[node.name] = out_shapes out = [] if outputs is None: out = op else: - out = [self._nodes[out_name][0] for out_name in outputs] + for out_name in outputs: + if ":" in out_name: + out_name, out_num = out_name.split(":") + out_num = int(out_num) + out.append(self._nodes[out_name][out_num]) + else: + out.append(self._nodes[out_name][0]) #Add the RNN outputs also with 'head' nodes of the relay graph if self._num_rnn_layer: diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index eae06ead71b6..10368ea3d9ab 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -127,7 +127,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, if no_gpu and device == 'cuda': continue - tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device) + tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device, + out_names=out_name, num_output=len(out_name)) # since the names from tensorflow and relay runs are not exactly same, # first len(tf_output) will be compared for i in range(len(tf_output)): @@ -170,7 +171,7 @@ def _test_pooling(input_shape, **kwargs): if is_gpu_available(): input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] - kwargs['data_layout'] = 'NCHW' + kwargs['data_format'] = 'NCHW' _test_pooling_iteration(input_shape, **kwargs) def test_forward_pooling(): @@ -227,8 +228,12 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32') in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32') - strides = [1] + strides + [1] - dilations = [1] + dilations + [1] + if data_format == 'NHWC': + strides = [1] + strides + [1] + dilations = [1] + dilations + [1] + else: + strides = [1, 1] + strides + dilations = [1, 1] + dilations nn_ops.conv2d(in_data, in_filter, @@ -504,6 +509,84 @@ def test_forward_gather(): _test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32') 
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32') +####################################################################### +# Split +# ----- + +def _test_split(in_shape, axis, num_or_size_splits, dtype): + np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype) + + """ One iteration of a Split """ + tf.reset_default_graph() + in_data = tf.placeholder(dtype, in_shape, name="in_data") + num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits + tf.split(in_data, num_or_size_splits, axis=axis) + + compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)]) + + # and now test together with concat + tf.reset_default_graph() + in_data = tf.placeholder(dtype, in_shape, name="in_data") + splitted = tf.split(in_data, num_or_size_splits, axis=axis) + tf.concat(splitted, axis) + + compare_tf_with_tvm([np_data], 'in_data:0', 'concat:0') + +def test_forward_split(): + '''test split layer''' + # rank 1 + _test_split((3,), 0, 1, 'float32') + _test_split((3,), 0, 3, 'float32') + _test_split((6,), 0, 3, 'float32') + # rank 2 + _test_split((6, 2), 0, 3, 'float32') + _test_split((2, 6), 1, 6, 'float32') + # rank 3 + _test_split((6, 2, 4), 0, 2, 'int32') + _test_split((2, 6, 4), 1, 3, 'float32') + _test_split((2, 4, 6), 2, 1, 'float32') + # rank 4 + _test_split((6, 1, 3, 5), 0, 3, 'float32') + _test_split((1, 6, 3, 5), 1, 3, 'float32') + _test_split((1, 3, 6, 5), 2, 3, 'float32') + _test_split((1, 3, 5, 6), 3, 3, 'float32') + # split along negative axis + _test_split((6, 1, 3, 5), -4, 3, 'float32') + _test_split((1, 6, 3, 5), -3, 3, 'float32') + _test_split((1, 3, 6, 5), -2, 3, 'float32') + _test_split((1, 3, 5, 6), -1, 3, 'float32') + # size_splits list + _test_split((6,), 0, [1, 2, 3], 'int32') + _test_split((3, 6, 4), -2, [1, 4, 1], 'float32') + + +####################################################################### +# Unstack +# ------- + +def _test_unstack(ip_shape, axis, dtype): + np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype) + + tf.reset_default_graph() + in_data = tf.placeholder(dtype, ip_shape, name="in_data") + tf.unstack(in_data, axis=axis) + + compare_tf_with_tvm([np_data], ['in_data:0'], [f'unstack:{n}' for n in range(ip_shape[axis])]) + + tf.reset_default_graph() + in_data = tf.placeholder(dtype, ip_shape, name="in_data") + tf.stack(tf.unstack(in_data, axis=axis), axis=axis) + + compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0') + +def test_forward_unstack(): + '''test unstack layer''' + _test_unstack((6,), 0, 'int32') + _test_unstack((2,6), 1, 'float64') + # negative axis + _test_unstack((1,4), -1, 'int32') + _test_unstack((3,6,4), -2, 'float32') + ####################################################################### # Multi Input to graph @@ -576,6 +659,22 @@ def test_forward_resize_bilinear(): _test_resize_bilinear((4, 16, 32, 32), [50, 50], False) _test_resize_bilinear((6, 32, 64, 64), [20, 20], True) +####################################################################### +# Crop to bounding box +# -------------------- + +def _test_crop(in_shape, off_h, off_w, tar_h, tar_w): + """ Crop to bounding box """ + data = np.random.uniform(size=in_shape).astype('float32') + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + tf.image.crop_to_bounding_box(in_data, off_h, off_w, tar_h, tar_w) + compare_tf_with_tvm(data, 'Placeholder:0', 'crop_to_bounding_box/Slice:0') + +def test_forward_crop(): + """ Crop to bounding box """ + 
_test_crop((1, 224, 224, 3), 20, 20, 120, 120) + ####################################################################### # LSTM @@ -804,7 +903,7 @@ def test_forward_mobilenet(): ####################################################################### # ResnetV2 -# --------- +# -------- def test_forward_resnetv2(): '''test resnet model''' if is_gpu_available(): @@ -818,8 +917,13 @@ def test_forward_resnetv2(): with tf.Session() as sess: tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0') - tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32') - tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) + for device in ["llvm", "cuda"]: + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + continue + tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device) + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) ####################################################################### # PTB @@ -1106,9 +1210,12 @@ def test_forward_rel_ops(): test_forward_squeeze() test_forward_pack() test_forward_resize_bilinear() + test_forward_crop() test_forward_pad() test_forward_gather() test_forward_stridedslice() + test_forward_split() + test_forward_unstack() # Activations test_forward_sigmoid()
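
A few worked sketches of the conversion semantics this patch touches follow; they are plain illustrative Python, and any helper that is not shown in the diff above is a hypothetical stand-in rather than code from the frontend.

The dilated-convolution fix computes SAME padding from the dilated kernel extent instead of the raw kernel size. A minimal sketch of that arithmetic, assuming TensorFlow's usual SAME rule (same_pad_pair below is a stand-in, not the frontend's _get_pad_pair):

    def same_pad_pair(input1d, kernel1d, stride1d):
        # SAME rule: produce ceil(input / stride) outputs, pad the shortfall evenly.
        out1d = (input1d + stride1d - 1) // stride1d
        pad = max((out1d - 1) * stride1d + kernel1d - input1d, 0)
        return [pad // 2, pad - pad // 2]

    in_h, kernel_h, stride_h, dilation_h = 28, 3, 1, 2
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1      # effective extent: 5
    print(same_pad_pair(in_h, kernel_h, stride_h))          # [1, 1] -- pad from raw kernel, too small
    print(same_pad_pair(in_h, dilated_kernel_h, stride_h))  # [2, 2] -- pad from dilated extent, matches TF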
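
The new Slice converter maps TensorFlow's (begin, size) pair onto the absolute (begin, end) indices that Relay's strided_slice expects, with a size of -1 meaning "everything remaining in that dimension". With hypothetical values (begin is 0 on the -1 axis, as in the common crop case):

    begin = [1, 0, 2]
    size = [2, -1, 3]
    data_shape = [4, 5, 8]
    # end is absolute: begin + size, or the remaining extent when size is -1.
    end = [data_shape[i] - begin[i] if size[i] == -1 else begin[i] + size[i]
           for i in range(len(data_shape))]
    print(end)   # [3, 5, 5]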
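
For SplitV, the converter requires size_splits (and the split axis) to be constants; it then turns the section sizes into the boundary offsets Relay's split expects, i.e. the running sums with the final total dropped. For example:

    import numpy as np

    size_splits = np.array([1, 4, 1])                  # tf.split(value, [1, 4, 1], axis=-2)
    section_beginnings = np.cumsum(size_splits)[:-1]   # array([1, 5])
    indices_or_sections = tuple(section_beginnings)    # (1, 5)
    # Splitting an axis of length 6 at offsets 1 and 5 yields pieces of
    # sizes 1, 4 and 1 -- exactly the sections TensorFlow returns.
    print(indices_or_sections)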
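
Unpack (tf.unstack) is lowered as a split into axis_length single-slice pieces followed by a squeeze of the split axis, so each piece drops that dimension. A rough NumPy analogue of the shape bookkeeping only:

    import numpy as np

    data = np.arange(12).reshape(3, 4)
    axis = 0
    axis_length = data.shape[axis]                       # must be known statically
    pieces = np.split(data, axis_length, axis=axis)      # three (1, 4) slices
    unstacked = [np.squeeze(p, axis=axis) for p in pieces]
    assert all(u.shape == (4,) for u in unstacked)       # three (4,) vectors, as tf.unstack gives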
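
The importer change for multi-output nodes parses an input name such as 'Model/RNN/cell_0/RnnCell:0' into the graph-node name plus an output slot, with a missing suffix meaning slot 0. Illustrative parsing only (parse_tensor_name is not a function in the frontend):

    def parse_tensor_name(name):
        parts = name.split(':')
        node_name = parts[0]
        tensor_slot = int(parts[1]) if len(parts) > 1 else 0
        return node_name, tensor_slot

    assert parse_tensor_name('Model/RNN/cell_0/RnnCell:0') == ('Model/RNN/cell_0/RnnCell', 0)
    assert parse_tensor_name('split:2') == ('split', 2)   # third output of a multi-output node
    assert parse_tensor_name('input_tensor') == ('input_tensor', 0)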