diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index bbbb0a2feaec..48f78837c525 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1080,6 +1080,15 @@ def _impl(inputs, attr, params): return _impl + +def _prod(): + def _impl(inputs, attr, params): + axis = params.pop(inputs[1].name_hint).asnumpy()[0] + keepdims = attr['keep_dims'] + return _op.prod(inputs[0], int(axis), keepdims=keepdims) + return _impl + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1136,6 +1145,7 @@ def _impl(inputs, attr, params): 'Pad' : _pad('Pad'), 'PadV2' : _pad('PadV2'), 'Pow' : _elemwise('power'), + 'Prod' : _prod(), 'Range' : _range(), 'Rank' : _rank(), 'RealDiv' : _elemwise('div'), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 8dd538aa859c..5ec8564d0b80 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -161,7 +161,6 @@ def is_gpu_available(): else: return False - ####################################################################### # Pooling # ------- @@ -1509,6 +1508,25 @@ def test_forward_expand_dims(): _test_forward_expand_dims(np.array([[1], [2]]), 1) _test_forward_expand_dims(np.array([[1], [2]]), -1) + +####################################################################### +# Prod +# ---- +def _test_forward_reduce_prod(shape, axis, keepdims): + inp_array1 = np.random.uniform(-5, 5, size=shape).astype(np.float32) + with tf.Graph().as_default(): + in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype) + out = tf.math.reduce_prod(in1, axis, keepdims) + compare_tf_with_tvm(inp_array1, in1.name, out.name) + +def test_forward_reduce_prod(): + _test_forward_reduce_prod((5,), 0, False) + _test_forward_reduce_prod((5, 5), 0, False) + _test_forward_reduce_prod((5, 5), 1, False) + _test_forward_reduce_prod((5,), 0, True) + _test_forward_reduce_prod((5, 5), 0, True) + _test_forward_reduce_prod((5, 5), 1, True) + ####################################################################### # Main # ---- @@ -1550,6 +1568,7 @@ def test_forward_expand_dims(): test_forward_argminmax() test_forward_reduce() test_forward_mean() + test_forward_reduce_prod() # General test_forward_multi_input()