Skip to content

Commit

Permalink
[Relay][Frontend][TFLite] Add constant input support for elemwise ops (
Browse files Browse the repository at this point in the history
…apache#4666)

* [Relay][Frontend][TFLite] Add constant input support for elemwise ops

* modify in tflite.py
  • Loading branch information
wyc-ruiker authored and zhiics committed Mar 2, 2020
1 parent 76d93db commit bb9db7f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
11 changes: 10 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,16 @@ def _convert_elemwise(self, relay_op, op):
assert len(input_tensors) == 2, "input tensors length should be 2"

lhs_tensor = input_tensors[0]
lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
if self.has_expr(lhs_tensor.tensor_idx):
            # In most cases we can assume that TOCO fuses the constant into
            # the elemwise operator, so both inputs already have expressions.
lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
else:
            # However, in some corner cases the elemwise operator is not
            # fused, so we can receive this input as a constant.
lhs_type_str = self.get_tensor_type_str(lhs_tensor.tensor.Type())
lhs_expr = self.exp_tab.new_const(self.get_tensor_value(lhs_tensor),
dtype=lhs_type_str)

rhs_tensor = input_tensors[1]
if self.has_expr(rhs_tensor.tensor_idx):
Expand Down
18 changes: 18 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,24 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data[0], ['in_0:0'], in_data, [out])

# Test with constant and tensor
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')]

if quantized:
inq_const = tf.quantization.fake_quant_with_min_max_args(data[0], min=-100, max=100, name="const_tensor")
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_1")]
# the 1st tensor is treated as constant and directly added as part of the operation
out = math_op(ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'), inq_data)
out = with_fused_activation_function(out, fused_activation_function)
out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
compare_tflite_with_tvm(data[1], ['inq_1:0'], inq_data, [out], quantized=True)
else:
out = math_op(ops.convert_to_tensor(data[0], dtype=data[0].dtype), in_data[0])
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data[1], ['in_1:0'], in_data, [out])

#######################################################################
# Add
# ---
Expand Down

0 comments on commit bb9db7f

Please sign in to comment.