From bb9db7f67b733e2f7b4aa243523fbbe68404eec5 Mon Sep 17 00:00:00 2001
From: Wang Yucheng
Date: Thu, 16 Jan 2020 00:48:08 +0800
Subject: [PATCH] [Relay][Frontend][TFLite] Add constant input support for
 elemwise ops (#4666)

* [Relay][Frontend][TFLite] Add constant input support for elemwise ops

* modify in tflite.py
---
 python/tvm/relay/frontend/tflite.py          | 11 ++++++++++-
 tests/python/frontend/tflite/test_forward.py | 18 ++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index cb6dbea4f41e..02b8ed980c0b 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -611,7 +611,16 @@ def _convert_elemwise(self, relay_op, op):
         assert len(input_tensors) == 2, "input tensors length should be 2"
 
         lhs_tensor = input_tensors[0]
-        lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
+        if self.has_expr(lhs_tensor.tensor_idx):
+            # In most cases, we can assume that TOCO fuses elemwise operators
+            # with constants, so both inputs will be tensors.
+            lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
+        else:
+            # However, in some corner cases, the elemwise operator is not fused;
+            # the lhs input then arrives as a constant.
+            lhs_type_str = self.get_tensor_type_str(lhs_tensor.tensor.Type())
+            lhs_expr = self.exp_tab.new_const(self.get_tensor_value(lhs_tensor),
+                                              dtype=lhs_type_str)
 
         rhs_tensor = input_tensors[1]
         if self.has_expr(rhs_tensor.tensor_idx):
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 1478b2538bba..837f0f611018 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -787,6 +787,24 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals
             out = with_fused_activation_function(out, fused_activation_function)
             compare_tflite_with_tvm(data[0], ['in_0:0'], in_data, [out])
 
+    # Test with constant and tensor
+    with tf.Graph().as_default():
+        in_data = [array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')]
+
+        if quantized:
+            inq_const = tf.quantization.fake_quant_with_min_max_args(data[0], min=-100, max=100, name="const_tensor")
+            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_1")]
+            # the 1st input is treated as a constant and folded directly into the operation
+            out = math_op(ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'), inq_data[0])
+            out = with_fused_activation_function(out, fused_activation_function)
+            out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
+            out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
+            compare_tflite_with_tvm(data[1], ['inq_1:0'], inq_data, [out], quantized=True)
+        else:
+            out = math_op(ops.convert_to_tensor(data[0], dtype=data[0].dtype), in_data[0])
+            out = with_fused_activation_function(out, fused_activation_function)
+            compare_tflite_with_tvm(data[1], ['in_1:0'], in_data, [out])
+
 #######################################################################
 # Add
 # ---
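
Usage sketch (not part of the patch): the snippet below is one way to exercise the new lhs-constant path end to end. It is a minimal sketch, assuming the TensorFlow 1.x session-based converter (tf.lite.TFLiteConverter.from_session) and the tflite flatbuffer bindings that the test suite above already requires; the tensor names const_in and var_in are illustrative, not from the patch.

import numpy as np
import tensorflow as tf
import tflite.Model
from tvm import relay

# Build an elemwise ADD whose FIRST operand is a constant -- the operand
# order the importer could not handle before this patch.
with tf.Graph().as_default():
    const_in = tf.constant(np.arange(6, dtype=np.float32).reshape((2, 3)),
                           name='const_in')
    var_in = tf.placeholder(tf.float32, shape=(2, 3), name='var_in')
    out = tf.add(const_in, var_in, name='out')

    with tf.Session() as sess:
        # TOCO may keep const_in as a standalone constant tensor instead of
        # fusing it into the ADD -- the "corner case" the patch describes.
        converter = tf.lite.TFLiteConverter.from_session(sess, [var_in], [out])
        tflite_model_buf = converter.convert()

# Import the flatbuffer into Relay. Before the patch, the importer called
# get_expr on the lhs tensor unconditionally, which fails when no expression
# is registered for a constant input; with the patch, the constant becomes a
# Relay constant via exp_tab.new_const.
tflite_model = tflite.Model.Model.GetRootAsModel(bytearray(tflite_model_buf), 0)
mod, params = relay.frontend.from_tflite(tflite_model,
                                         shape_dict={'var_in': (2, 3)},
                                         dtype_dict={'var_in': 'float32'})
print(mod['main'])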