diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index ac52ab768066..3a3c5fcecd42 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1976,6 +1976,16 @@ def _impl(inputs, attr, params, mod): # Symbolic delta delta = inputs[2] + # if all attributes are constant, evalute the range function and return relay.const + if all( + [ + isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)), + isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)), + isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)), + ] + ): + return tvm.relay.const(list(range(int(start), int(limit), int(delta)))) + dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype) if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)): start = _expr.const(start, dtype=dtype) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index ecf6441bc6b9..d0038caea09f 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -3948,6 +3948,45 @@ def _test_math_op(op, dtypes=["int32", "float32"]): _test_math_op(tf.math.reduce_euclidean_norm) +####################################################################### +# All, Max, Min +# ------------------------------------------------------------------ + + +def test_forward_raw_reduce(): + def _check_op(tf_op, ishape, axis, keepdims, range_axis=False, dtype="float32"): + tf.reset_default_graph() + if dtype == "bool": + np_data = np.random.choice([True, False], size=ishape) + else: + np_data = np.random.uniform(size=ishape).astype(dtype) + if tf_op == tf.math.reduce_prod: + axis = 1 + np_data = np_data.reshape(1, -1) + with tf.Graph().as_default(): + if range_axis: + axis = tf.range(axis[0], axis[1], axis[2], name="range", dtype="int32") + in_data = tf.placeholder(dtype, name="in_data") + reduce_op = tf_op(input=in_data, axis=axis, keep_dims=keepdims, name="reduce_std") + compare_tf_with_tvm([np_data], ["in_data:0"], reduce_op.name) + + def _test_raw_reduce_op(op, dtypes=["int32", "float32"]): + for dtype in dtypes: + _check_op(op, (3, 10), axis=(-1), keepdims=False, dtype=dtype) + _check_op(op, (8, 16, 32), axis=(-1), keepdims=False, dtype=dtype) + _check_op(op, (1, 8, 8, 3), axis=(2, 3), keepdims=True, dtype=dtype) + _check_op(op, (2, 3, 10, 10), axis=(1, 2), keepdims=True, dtype=dtype) + _check_op(op, (1, 8, 8, 3), axis=(2, 4, 1), keepdims=True, range_axis=True, dtype=dtype) + _check_op( + op, (2, 3, 10, 10), axis=(1, 3, 1), keepdims=True, range_axis=True, dtype=dtype + ) + + if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"): + _test_raw_reduce_op(tf.raw_ops.All, dtypes=["bool"]) + _test_raw_reduce_op(tf.raw_ops.Max) + _test_raw_reduce_op(tf.raw_ops.Min) + + ####################################################################### # Relational operators # --------------------