[QNN][TFLite] Parsing QNN Add op. Adding MobilenetV2.
anijain2305 committed Oct 17, 2019
1 parent 02c1e11 commit 1c2acb8
Showing 2 changed files with 88 additions and 2 deletions.
68 changes: 67 additions & 1 deletion python/tvm/relay/frontend/tflite.py
@@ -224,6 +224,20 @@ def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
        return lhs_tensor.qnn_params['scale'] == rhs_tensor.qnn_params['scale'] and \
               lhs_tensor.qnn_params['zero_point'] == rhs_tensor.qnn_params['zero_point']

    def is_qnn_op(self, op):
        """Check whether the operator's input tensors are quantized."""
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        first_tensor = input_tensors[0]
        if first_tensor.qnn_params:
            return True
        return False
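For context, the qnn_params dict consulted here is attached to each tensor wrapper when the graph's tensors are parsed. A minimal sketch of reading per-tensor quantization metadata from the TFLite flatbuffer; the helper name read_qnn_params is illustrative, not this frontend's exact code, and it assumes per-tensor (not per-channel) quantization:

def read_qnn_params(tensor):
    """Return {'scale', 'zero_point'} if the flatbuffer tensor is quantized, else None."""
    quant = tensor.Quantization()
    if quant is None or quant.ScaleLength() == 0:
        return None
    # Assumes a single (scale, zero_point) pair, i.e. per-tensor quantization.
    scale = float(quant.ScaleAsNumpy()[0])
    zero_point = int(quant.ZeroPointAsNumpy()[0])
    # An all-zero pair means the tensor carries no real quantization info.
    if scale == 0.0 and zero_point == 0:
        return None
    return {'scale': scale, 'zero_point': zero_point}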

    def convert_conv2d(self, op):
        """Convert TFLite conv2d"""
        return self.convert_conv(op, "conv2d")
@@ -498,7 +512,25 @@ def _convert_elemwise(self, relay_op, op):
            rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type())
            rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
                                              dtype=rhs_type_str)

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        output_tensor = output_tensors[0]

        # If quantized, extract the qnn params and call the QNN add operator.
        if lhs_tensor.qnn_params:
            assert rhs_tensor.qnn_params, "Both tensors should be quantized."
            assert output_tensor.qnn_params, "Output tensor should be quantized."
            out = relay_op(lhs=lhs_expr,
                           rhs=rhs_expr,
                           lhs_scale=lhs_tensor.qnn_params['scale'],
                           lhs_zero_point=lhs_tensor.qnn_params['zero_point'],
                           rhs_scale=rhs_tensor.qnn_params['scale'],
                           rhs_zero_point=rhs_tensor.qnn_params['zero_point'],
                           output_scale=output_tensor.qnn_params['scale'],
                           output_zero_point=output_tensor.qnn_params['zero_point'])
        else:
            out = relay_op(lhs_expr, rhs_expr)

        # Options (fused_activation_function)
        options = None
@@ -517,36 +549,70 @@ def _convert_elemwise(self, relay_op, op):
            fused_activation_fn = options.FusedActivationFunction()
            # If we have an activation fn
            if fused_activation_fn != ActivationFunctionType.NONE:
                if output_tensor.qnn_params:
                    raise tvm.error.OpNotImplemented(
                        'Quantized elemwise operators with fused activation are not supported yet.')
                out = self.convert_fused_activation_function(out, fused_activation_fn)

        return out
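
The quantized branch above forwards the affine params of both inputs and the output to the QNN operator (for ADD, _qnn.op.add). As a reference for the intended numerics, here is a minimal numpy sketch under the usual affine model real = scale * (q - zero_point); the lowered integer-only kernel may round slightly differently:

import numpy as np

def qnn_add_reference(lhs_q, lhs_scale, lhs_zp,
                      rhs_q, rhs_scale, rhs_zp,
                      out_scale, out_zp):
    # Dequantize each input: real = scale * (q - zero_point).
    lhs_real = lhs_scale * (lhs_q.astype(np.float32) - lhs_zp)
    rhs_real = rhs_scale * (rhs_q.astype(np.float32) - rhs_zp)
    # Add in the real domain, then requantize with the output params.
    out_q = np.round((lhs_real + rhs_real) / out_scale) + out_zp
    # Assumes uint8 tensors, as in the quantized Mobilenet models.
    return np.clip(out_q, 0, 255).astype(np.uint8)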

    def convert_add(self, op):
        """Convert TFLite ADD"""
        # If the input tensors are quantized, call the QNN add operator
        if self.is_qnn_op(op):
            return self._convert_elemwise(_qnn.op.add, op)
        return self._convert_elemwise(_op.add, op)

    def convert_sub(self, op):
        """Convert TFLite SUB"""
        # Quantized version of this operator is not supported yet
        if self.is_qnn_op(op):
            raise tvm.error.OpNotImplemented(
                'TFLite quantized sub operator is not supported yet.')
        return self._convert_elemwise(_op.subtract, op)

    def convert_mul(self, op):
        """Convert TFLite MUL"""
        # Quantized version of this operator is not supported yet
        if self.is_qnn_op(op):
            raise tvm.error.OpNotImplemented(
                'TFLite quantized mul operator is not supported yet.')
        return self._convert_elemwise(_op.multiply, op)

    def convert_div(self, op):
        """Convert TFLite DIV"""
        # Quantized version of this operator is not supported yet
        if self.is_qnn_op(op):
            raise tvm.error.OpNotImplemented(
                'TFLite quantized div operator is not supported yet.')
        return self._convert_elemwise(_op.divide, op)

    def convert_pow(self, op):
        """Convert TFLite POW"""
        # Quantized version of this operator is not supported yet
        if self.is_qnn_op(op):
            raise tvm.error.OpNotImplemented(
                'TFLite quantized pow operator is not supported yet.')
        return self._convert_elemwise(_op.power, op)

    def convert_maximum(self, op):
        """Convert TFLite MAXIMUM"""
        # Quantized version of this operator is not supported yet
        if self.is_qnn_op(op):
            raise tvm.error.OpNotImplemented(
                'TFLite quantized maximum operator is not supported yet.')
        return self._convert_elemwise(_op.maximum, op)

    def convert_minimum(self, op):
        """Convert TFLite MINIMUM"""
        # Quantized version of this operator is not supported yet
        if self.is_qnn_op(op):
            raise tvm.error.OpNotImplemented(
                'TFLite quantized minimum operator is not supported yet.')
        return self._convert_elemwise(_op.minimum, op)

    def convert_greater(self, op):
        """Convert TFLite GREATER"""
        # Quantized version of this operator is not supported yet
        if self.is_qnn_op(op):
            raise tvm.error.OpNotImplemented(
                'TFLite quantized greater operator is not supported yet.')
        return self._convert_elemwise(_op.greater, op)

    def convert_zeros_like(self, op):
22 changes: 21 additions & 1 deletion tests/python/frontend/tflite/test_forward.py
@@ -1037,6 +1037,26 @@ def test_forward_qnn_mobilenet_v1_net():
    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

def test_forward_qnn_mobilenet_v2_net():
    """Test the Quantized TFLite Mobilenet V2 model."""
    # MobilenetV2
    tflite_model_file = tf_testing.get_workload_official(
        "https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz",
        "mobilenet_v2_1.0_224_quant.tflite")
    with open(tflite_model_file, "rb") as f:
        tflite_model_buf = f.read()
    # Compare labels rather than raw output values because the requantize
    # implementation differs between TFLite and Relay, which causes the final
    # output numbers to mismatch. So, test accuracy via the top-3 labels.
    np.random.seed(0)
    data = np.random.randint(low=0, high=129, size=(1, 224, 224, 3)).astype('uint8')
    tflite_output = run_tflite_graph(tflite_model_buf, data)
    tflite_predictions = np.squeeze(tflite_output)
    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
    tvm_predictions = np.squeeze(tvm_output)
    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
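
For orientation, run_tvm_graph above wraps the usual convert-build-run flow. A condensed sketch of that flow, assuming tflite_model_buf and data as in the test (API names as of this commit's vintage; the test helper's exact details may differ):

import tvm
from tvm import relay
from tvm.contrib import graph_runtime
import tflite.Model

# Parse the flatbuffer and convert the graph to Relay.
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
mod, params = relay.frontend.from_tflite(tflite_model,
                                         shape_dict={'input': (1, 224, 224, 3)},
                                         dtype_dict={'input': 'uint8'})
# Compile for CPU and run.
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target='llvm', params=params)
m = graph_runtime.create(graph, lib, tvm.cpu(0))
m.set_input('input', tvm.nd.array(data))
m.set_input(**params)
m.run()
tvm_output = m.get_output(0).asnumpy()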

#######################################################################
# SSD Mobilenet
# -------------
@@ -1111,6 +1131,6 @@ def test_forward_ssd_mobilenet_v1():
    test_forward_ssd_mobilenet_v1()

    # End to End quantized
    test_forward_qnn_inception_v1_net()
    test_forward_qnn_mobilenet_v1_net()
    test_forward_qnn_mobilenet_v2_net()
