From 2aac336926bd9b9f68c2ceee1029fe454fcc5cb6 Mon Sep 17 00:00:00 2001 From: Neo Chien Date: Tue, 12 Nov 2019 12:34:50 +0800 Subject: [PATCH] [Relay][Frontend][Tensorflow] Fix type assignment for operator 'tf.range' (#4294) --- python/tvm/relay/frontend/tensorflow.py | 6 ++++-- tests/python/frontend/tensorflow/test_forward.py | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 6a24e74636b5..5a17d5f7b8f0 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1075,6 +1075,7 @@ def _impl(inputs, attr, params): return _impl + def _range(): def _impl(inputs, attr, params): start = _get_param(params, inputs[0])[0] @@ -1082,7 +1083,7 @@ def _impl(inputs, attr, params): if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \ else params.pop('Rank').asnumpy()[0] delta = _get_param(params, inputs[2])[0] - dtype = attr['dtype'].name if 'dtype' in attr else "int32" + dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype) return AttrCvt( op_name="arange", ignores=['Tidx'], @@ -1092,6 +1093,7 @@ def _impl(inputs, attr, params): 'dtype': dtype})([], attr) return _impl + def _elu(): def _impl(inputs, attr, params): dtype = attr['T'].name @@ -1202,7 +1204,7 @@ def _impl(inputs, attr, params): raise tvm.error.OpAttributeInvalid( 'Attribute k must be positive in operator TopKV2') if attr['sorted'] is False: - raise tvm.error.OpAttributeUnimplemented( + raise tvm.error.OpAttributeUnImplemented( 'Attribute sorted=False is not supported in operator TopKV2') return AttrCvt(op_name='topk', ignores=['sorted'], diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 4790af32799c..30b6dfebf8ee 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1638,6 +1638,11 @@ def test_forward_range(): tf.range(1, 18, 3, name="range") compare_tf_with_tvm([], [], 'range:0') + """test type assignment for operator Range""" + tf.reset_default_graph() + tf.range(1, 256 + 1, 1, dtype=tf.float32) + compare_tf_with_tvm([], [], 'range:0') + ####################################################################### # Pad # ---