Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][TFLite] Dynamically calculate input_stats of any fake_quant range #4789

Merged
merged 2 commits into from
Feb 5, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def run_tflite_graph(tflite_model_buf, input_data):


def compare_tflite_with_tvm(in_data, in_name, input_tensors,
output_tensors, init_global_variables=False, out_names=None, quantized=False):
output_tensors, init_global_variables=False,
out_names=None, quantized=False, input_range=None):
"""Generic function to generate and compare TFLite and TVM output"""
in_data = convert_to_list(in_data)
in_name = convert_to_list(in_name)
Expand All @@ -143,11 +144,16 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
input_arrays = converter.get_input_arrays()
input_stats = {}
# hardcode the mean_values and std_dev_values (m,s) to be the same
# if all inputs are in (float_min; float_max) == (-100, 100)
# calculate the mean and quantization scale for every input tensor,
# with respect to its fp32 input range, defined in fake_quant.
# s = 255/(fmax-fmin); m = -fmin*s (the zero point)
for i in input_arrays:
input_stats[i] = (128., 1.275)
try:
quant_scale = 255 / (input_range[i][1] - input_range[i][0])
except ZeroDivisionError:
raise ZeroDivisionError('Min and max of the input range for tensor ' + i + ' can\'t be equal')
mean = - input_range[i][0] * quant_scale
input_stats[i] = (mean, quant_scale)
converter.quantized_input_stats = input_stats

tflite_model_buffer = converter.convert()
Expand Down Expand Up @@ -757,13 +763,15 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals
if quantized:
# fake_quant will keep the tensors in float32 until the conversion in the session
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0"),
tf.quantization.fake_quant_with_min_max_args(in_data[1], min=-100, max=100, name="inq_1")]
tf.quantization.fake_quant_with_min_max_args(in_data[1], min=-50, max=50, name="inq_1")]
input_range = {'inq_0': (-100, 100), 'inq_1': (-50, 50)}
out = math_op(inq_data[0], inq_data[1])
out = with_fused_activation_function(out, fused_activation_function)
# set the quantized output range with respect to the operation
# set the fp32 output range with respect to the operation
out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
compare_tflite_with_tvm(data, ['inq_0:0', 'inq_1:0'], inq_data, [out], quantized=True)
compare_tflite_with_tvm(data, ['inq_0:0', 'inq_1:0'], inq_data, [out],
quantized=True, input_range=input_range)
else:
out = math_op(in_data[0], in_data[1])
out = with_fused_activation_function(out, fused_activation_function)
Expand All @@ -775,13 +783,14 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals

if quantized:
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0")]
inq_const = tf.quantization.fake_quant_with_min_max_args(data[1], min=-100, max=100, name="const_tensor")
inq_const = tf.quantization.fake_quant_with_min_max_args(data[1], min=-50, max=50, name="const_tensor")
input_range = {'inq_0': (-100, 100)}
# the 2nd tensor is treated as constant and directly added as part of the operation
out = math_op(inq_data, ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'))
out = with_fused_activation_function(out, fused_activation_function)
out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
compare_tflite_with_tvm(data[0], ['inq_0:0'], inq_data, [out], quantized=True)
compare_tflite_with_tvm(data[0], ['inq_0:0'], inq_data, [out], quantized=True, input_range=input_range)
else:
out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
out = with_fused_activation_function(out, fused_activation_function)
Expand All @@ -793,13 +802,14 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals

if quantized:
inq_const = tf.quantization.fake_quant_with_min_max_args(data[0], min=-100, max=100, name="const_tensor")
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_1")]
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-50, max=50, name="inq_1")]
input_range = {'inq_1': (-50, 50)}
# the 1st tensor is treated as constant and directly added as part of the operation
out = math_op(ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'), inq_data)
out = with_fused_activation_function(out, fused_activation_function)
out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
compare_tflite_with_tvm(data[1], ['inq_1:0'], inq_data, [out], quantized=True)
compare_tflite_with_tvm(data[1], ['inq_1:0'], inq_data, [out], quantized=True, input_range=input_range)
else:
out = math_op(ops.convert_to_tensor(data[0], dtype=data[0].dtype), in_data[0])
out = with_fused_activation_function(out, fused_activation_function)
Expand Down Expand Up @@ -886,11 +896,11 @@ def _test_forward_elemwise_quantized(testop):
np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8)], quantized=True, qnn_op=testop)

def _test_elemwise_qnn_out_range(qnn_op):
# set the fake_quant output range if input tensors are in [-100, 100] float32
# set the fake_quant output range with respect to the input tensors float32 range
qnn_out_range = {
_test_add: (-200, 200),
_test_sub: (-200, 200),
_test_mul: (-1e+4, 1e+4),
_test_add: (-150, 150),
_test_sub: (-150, 150),
_test_mul: (-5e+3, 5e+3),
}

return qnn_out_range[qnn_op]
Expand Down Expand Up @@ -955,9 +965,10 @@ def _test_reduce_quantize(math_op, data, keep_dims=None):
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype="float32", name='in')]
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0")]
input_range = {'inq_0': (-100, 100)}
out = math_op(inq_data, data[1], keep_dims)
out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
compare_tflite_with_tvm([data[0]], ['inq_0:0'], [inq_data[0]], [out], quantized=True)
compare_tflite_with_tvm([data[0]], ['inq_0:0'], [inq_data[0]], [out], quantized=True, input_range=input_range)


#######################################################################
Expand Down