Skip to content

Commit

Permalink
[TFLite] Cast operator adapted for MLIR-based convertor (apache#7639)
Browse files Browse the repository at this point in the history
* [TFLite] Cast operator adapted for MLIR-based convertor

Cast operator now can be executed in MLIR-based version.
Unit test updated

Change-Id: I30e5c1c9d69355116b560af8f6d0582b2d593538

* Comment added

Change-Id: I3e2d29ef201283de337168d0b82679b63ca2fcf4
  • Loading branch information
d-smirnov authored and Trevor Morris committed May 6, 2021
1 parent d9041ce commit 9fd73b6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
17 changes: 12 additions & 5 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2336,11 +2336,18 @@ def convert_cast(self, op):
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions
op_options = op.BuiltinOptions()
cast_options = CastOptions()
cast_options.Init(op_options.Bytes, op_options.Pos)
cast_dtype = cast_options.OutDataType()
# MLIR-based converter outputs no BuiltinOptions for Cast operator. In this
# case the output type can be derived from the Cast operator output tensor.
# When TOCO converter is used there will be "normal" BuiltinOptions.CastOptions
# with output type.
if op.BuiltinOptions() is not None:
assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions
op_options = op.BuiltinOptions()
cast_options = CastOptions()
cast_options.Init(op_options.Bytes, op_options.Pos)
cast_dtype = cast_options.OutDataType()
else:
cast_dtype = self.get_output_tensors(op)[0].tensor.Type()

out = _op.cast(in_expr, self.get_tensor_type_str(cast_dtype))

Expand Down
19 changes: 14 additions & 5 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,19 +647,28 @@ def test_forward_transpose():
# ----


def _test_cast(data, cast_dtype):
def _test_cast(data, cast_dtype, use_mlir=False):
""" One iteration of CAST """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = math_ops.cast(in_data, cast_dtype)
compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
compare_tflite_with_tvm(
data, "Placeholder:0", [in_data], [out], experimental_new_converter=use_mlir
)


def test_forward_cast():
""" CAST """
_test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32)
_test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8)
_test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)
for use_mlir in [False, True]:
_test_cast(
np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32, use_mlir=use_mlir
)
_test_cast(
np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8, use_mlir=use_mlir
)
_test_cast(
np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64, use_mlir=use_mlir
)


#######################################################################
Expand Down

0 comments on commit 9fd73b6

Please sign in to comment.