Skip to content

Commit

Permalink
[Frontend][Tensorflow] Support range like axis in tf.raw_ops.All for …
Browse files Browse the repository at this point in the history
…TF 2.x (#7502)

* add TF2.x raw_ops.all axis range support

* apply linting

* fix range() func input
  • Loading branch information
Xingyu Zhou authored Feb 24, 2021
1 parent 88a4fdd commit 7f86987
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1976,6 +1976,16 @@ def _impl(inputs, attr, params, mod):
# Symbolic delta
delta = inputs[2]

# if all attributes are constant, evalute the range function and return relay.const
if all(
[
isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)),
isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)),
isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)),
]
):
return tvm.relay.const(list(range(int(start), int(limit), int(delta))))

dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype)
if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)):
start = _expr.const(start, dtype=dtype)
Expand Down
39 changes: 39 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3948,6 +3948,45 @@ def _test_math_op(op, dtypes=["int32", "float32"]):
_test_math_op(tf.math.reduce_euclidean_norm)


#######################################################################
# All, Max, Min
# ------------------------------------------------------------------


def test_forward_raw_reduce():
def _check_op(tf_op, ishape, axis, keepdims, range_axis=False, dtype="float32"):
tf.reset_default_graph()
if dtype == "bool":
np_data = np.random.choice([True, False], size=ishape)
else:
np_data = np.random.uniform(size=ishape).astype(dtype)
if tf_op == tf.math.reduce_prod:
axis = 1
np_data = np_data.reshape(1, -1)
with tf.Graph().as_default():
if range_axis:
axis = tf.range(axis[0], axis[1], axis[2], name="range", dtype="int32")
in_data = tf.placeholder(dtype, name="in_data")
reduce_op = tf_op(input=in_data, axis=axis, keep_dims=keepdims, name="reduce_std")
compare_tf_with_tvm([np_data], ["in_data:0"], reduce_op.name)

def _test_raw_reduce_op(op, dtypes=["int32", "float32"]):
for dtype in dtypes:
_check_op(op, (3, 10), axis=(-1), keepdims=False, dtype=dtype)
_check_op(op, (8, 16, 32), axis=(-1), keepdims=False, dtype=dtype)
_check_op(op, (1, 8, 8, 3), axis=(2, 3), keepdims=True, dtype=dtype)
_check_op(op, (2, 3, 10, 10), axis=(1, 2), keepdims=True, dtype=dtype)
_check_op(op, (1, 8, 8, 3), axis=(2, 4, 1), keepdims=True, range_axis=True, dtype=dtype)
_check_op(
op, (2, 3, 10, 10), axis=(1, 3, 1), keepdims=True, range_axis=True, dtype=dtype
)

if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
_test_raw_reduce_op(tf.raw_ops.All, dtypes=["bool"])
_test_raw_reduce_op(tf.raw_ops.Max)
_test_raw_reduce_op(tf.raw_ops.Min)


#######################################################################
# Relational operators
# --------------------
Expand Down

0 comments on commit 7f86987

Please sign in to comment.