diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 37f2c239f520..0c5b40cf05d1 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -683,10 +683,10 @@ def _impl(inputs, attr, params): new_input = [] new_input.append(inputs.pop(0)) new_input.append(inputs.pop(0)) - return AttrCvt(op_name="take", - extras={'axis': tvm.const(axis, 'int32')}, - ignores=['Tindices', 'Tparams', 'validate_indices', \ - 'Taxis', '_class'])(new_input, attr) + return AttrCvt(op_name="take", + extras={'axis': tvm.const(axis, 'int32')}, + ignores=['Tindices', 'Tparams', 'validate_indices', \ + 'Taxis', '_class'])(new_input, attr) return _impl def _infer_out_shapes(inputs, params): @@ -818,7 +818,6 @@ def _impl(inputs, attr, params): ignores=['Tpaddings'],)(new_inputs, attr) return _impl - def _transpose(): def _impl(inputs, attr, params): # If perm is not specified, axes is left empty, @@ -831,6 +830,11 @@ def _impl(inputs, attr, params): return _op.transpose(inputs[0], axes=axes) return _impl +def _where(): + def _impl(inputs, attr, params): + return AttrCvt(op_name="where")(inputs, attr) + return _impl + def _rank(): def _impl(inputs, attr, params): input_shape = attr['_input_shapes'][inputs[0]] @@ -1015,6 +1019,7 @@ def _impl(inputs, attr, params): 'DepthwiseConv2dNative' : _conv('depthwise'), 'Shape' : _shape(), 'Sigmoid' : AttrCvt('sigmoid'), + 'Select' : _where(), 'Fill' : _fill(), 'GatherV2' : _gather(), 'Gather' : _gather(), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 87fe53ec7434..9d26280d470c 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -108,7 +108,6 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, in_node = [0]*len(in_name) for i in range(len(in_name)): in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i] - with tf.Session() as sess: if init_global_variables: sess.run(variables.global_variables_initializer()) @@ -483,7 +482,7 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype): in_data = tf.placeholder(dtype, ip_shape, name="in_data") indices = tf.placeholder("int32", indice_shape, name="indices") tf.gather(in_data, indices, axis=axis) - np_data = np.random.uniform(size=ip_shape).astype(dtype) + np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype) def _fill_indices(indice_value): indices = np.array(ip_shape, dtype=dtype) @@ -500,14 +499,14 @@ def test_forward_gather(): '''test GatherV2 layer''' _test_gather((4,), (1,), 1, 0, 'int32') _test_gather((4,), (1,), 1, 0, 'float32') - _test_gather((1,4), (1,), [0], 0, 'int32') - _test_gather((4,), (1,2,2), [[[1,0],[0,1]]], 0, 'float32') - _test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'int32') - _test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 1, 'int32') - _test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'float32') - _test_gather((3,3,3), (1,1,2), [[[1,0]]], 0, 'int32') - _test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32') - _test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32') + _test_gather((1, 4), (1,), [0], 0, 'int32') + _test_gather((4,), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'float32') + _test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'int32') + _test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 1, 'int32') + _test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'float32') + _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 0, 'int32') + _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32') + _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32') def _test_gather_v1(ip_shape, indice_shape, indice_value, dtype): @@ -620,10 +619,10 @@ def _test_unstack(ip_shape, axis, dtype): def test_forward_unstack(): '''test unstack layer''' _test_unstack((6,), 0, 'int32') - _test_unstack((2,6), 1, 'float64') + _test_unstack((2, 6), 1, 'float64') # negative axis - _test_unstack((1,4), -1, 'int32') - _test_unstack((3,6,4), -2, 'float32') + _test_unstack((1, 4), -1, 'int32') + _test_unstack((3, 6, 4), -2, 'float32') ####################################################################### @@ -863,6 +862,22 @@ def test_forward_logical(): test_logical_not() +####################################################################### +# Where, Select +# ------------- +def test_where(): + ''' Where: return elements depending on conditions''' + with tf.Graph().as_default(): + with tf.Session() as sess: + input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input1') + input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input2') + mask = input1 > input2 + tf.where(mask, input1 + 1, input2 * 2) + in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32") + in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32") + compare_tf_with_tvm([in_data1, in_data2], ['input1:0', 'input2:0'], 'Select:0') + + ####################################################################### # Inception V3 # ------------ @@ -1299,3 +1314,4 @@ def test_forward_rel_ops(): # Relational ops test_forward_rel_ops() test_forward_logical() + test_where()