From 5607ab401efeed930ef9f8faebb1241219c48ce0 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 9 Nov 2019 21:23:01 +0800 Subject: [PATCH] [Relay][Frontend][Tensorflow] Fix type assignment for operator 'tf.range' --- python/tvm/relay/frontend/tensorflow.py | 6 ++++-- tests/python/frontend/tensorflow/test_forward.py | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 0abcb09d6ace..2f0a25b8dd1e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1067,6 +1067,7 @@ def _impl(inputs, attr, params): return _impl + def _range(): def _impl(inputs, attr, params): start = _get_param(params, inputs[0])[0] @@ -1074,7 +1075,7 @@ def _impl(inputs, attr, params): if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \ else params.pop('Rank').asnumpy()[0] delta = _get_param(params, inputs[2])[0] - dtype = attr['dtype'].name if 'dtype' in attr else "int32" + dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype) return AttrCvt( op_name="arange", ignores=['Tidx'], @@ -1084,6 +1085,7 @@ def _impl(inputs, attr, params): 'dtype': dtype})([], attr) return _impl + def _elu(): def _impl(inputs, attr, params): dtype = attr['T'].name @@ -1194,7 +1196,7 @@ def _impl(inputs, attr, params): raise tvm.error.OpAttributeInvalid( 'Attribute k must be positive in operator TopKV2') if attr['sorted'] is False: - raise tvm.error.OpAttributeUnimplemented( + raise tvm.error.OpAttributeUnImplemented( 'Attribute sorted=False is not supported in operator TopKV2') return AttrCvt(op_name='topk', ignores=['sorted'], diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 0d2748062c08..75e3701c8049 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1617,6 +1617,11 @@ def test_forward_range(): tf.range(1, 18, 3, name="range") compare_tf_with_tvm([], [], 'range:0') + """test type assignment for operator Range""" + tf.reset_default_graph() + tf.range(1, 256 + 1, 1, dtype=tf.float32) + compare_tf_with_tvm([], [], 'range:0') + ####################################################################### # Pad # ---