From 9fd73b68f2295accc51b261840188fe229750143 Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Fri, 19 Mar 2021 06:47:45 +0000 Subject: [PATCH] [TFLite] Cast operator adapted for MLIR-based convertor (#7639) * [TFLite] Cast operator adapted for MLIR-based convertor Cast operator now can be executed in MLIR-based version. Unit test updated Change-Id: I30e5c1c9d69355116b560af8f6d0582b2d593538 * Comment added Change-Id: I3e2d29ef201283de337168d0b82679b63ca2fcf4 --- python/tvm/relay/frontend/tflite.py | 17 ++++++++++++----- tests/python/frontend/tflite/test_forward.py | 19 ++++++++++++++----- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d6f704703cae..a5c9a586e275 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2336,11 +2336,18 @@ def convert_cast(self, op): input_tensor = input_tensors[0] in_expr = self.get_expr(input_tensor.tensor_idx) - assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions - op_options = op.BuiltinOptions() - cast_options = CastOptions() - cast_options.Init(op_options.Bytes, op_options.Pos) - cast_dtype = cast_options.OutDataType() + # MLIR-based converter outputs no BuiltinOptions for Cast operator. In this + # case the output type can be derived from the Cast operator output tensor. + # When TOCO converter is used there will be "normal" BuiltinOptions.CastOptions + # with output type. + if op.BuiltinOptions() is not None: + assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions + op_options = op.BuiltinOptions() + cast_options = CastOptions() + cast_options.Init(op_options.Bytes, op_options.Pos) + cast_dtype = cast_options.OutDataType() + else: + cast_dtype = self.get_output_tensors(op)[0].tensor.Type() out = _op.cast(in_expr, self.get_tensor_type_str(cast_dtype)) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 0d02c15f2eb8..7c12cd3365ca 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -647,19 +647,28 @@ def test_forward_transpose(): # ---- -def _test_cast(data, cast_dtype): +def _test_cast(data, cast_dtype, use_mlir=False): """ One iteration of CAST """ with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) out = math_ops.cast(in_data, cast_dtype) - compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out]) + compare_tflite_with_tvm( + data, "Placeholder:0", [in_data], [out], experimental_new_converter=use_mlir + ) def test_forward_cast(): """ CAST """ - _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32) - _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8) - _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64) + for use_mlir in [False, True]: + _test_cast( + np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32, use_mlir=use_mlir + ) + _test_cast( + np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8, use_mlir=use_mlir + ) + _test_cast( + np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64, use_mlir=use_mlir + ) #######################################################################