Skip to content

Commit

Permalink
Add test for the qnn_add operator (#4282)
Browse files Browse the repository at this point in the history
* Add test for the qnn_add operator

The tests use fake quant approach so until the tf session tensors remain in float32.
The test data has to be passed in uint8 because of how the tflite/tvm comparison works.
Abs tolerance up to 1 is allowed for the qnn results. For now input_stats are hardcoded
assuming the tests for the other qnn ops will pass the input data in the same range.

* Separate qnn uint8 test function from the fp32 elemwise tests

Isolate qnn uint8 elemwise tests
Remove blank lines
  • Loading branch information
inadob authored and zhiics committed Nov 12, 2019
1 parent dddb0ed commit e680611
Showing 1 changed file with 57 additions and 17 deletions.
74 changes: 57 additions & 17 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def run_tflite_graph(tflite_model_buf, input_data):


def compare_tflite_with_tvm(in_data, in_name, input_tensors,
output_tensors, init_global_variables=False, out_names=None):
output_tensors, init_global_variables=False, out_names=None, quantized=False):
"""Generic function to generate and compare TFLite and TVM output"""
in_data = convert_to_list(in_data)
in_name = convert_to_list(in_name)
Expand All @@ -137,6 +137,17 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
# convert to tflite model
converter = interpreter_wrapper.TFLiteConverter.from_session(
sess, input_tensors, output_tensors)

if quantized:
converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
input_arrays = converter.get_input_arrays()
input_stats = {}
# hardcode the mean_values and std_dev_values (m,s) to be the same for all inputs
# s = 255/(fmax-fmin); m = -fmin*s (the zero point)
for i in input_arrays:
input_stats[i] = (128., 1.275)
converter.quantized_input_stats = input_stats

tflite_model_buffer = converter.convert()
tflite_output = run_tflite_graph(tflite_model_buffer, in_data)

Expand All @@ -148,8 +159,13 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,

tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device,
num_output=len(out_names), out_names=out_names)
for i in range(len(tflite_output)):
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
if quantized:
for i in range(len(tflite_output)):
# allow absolute tolerance of 1 in the quantized results
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1, rtol=1e-5)
else:
for i in range(len(tflite_output)):
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)


def with_fused_activation_function(input_tensor, fn_name):
Expand Down Expand Up @@ -545,34 +561,53 @@ def test_forward_concatenation():
# Element-wise
# ---

def _test_elemwise(math_op, data, fused_activation_function=None):
def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False):
""" One iteration of elemwise """

assert len(data) == 2

# Test with two tensors
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
out = math_op(in_data[0], in_data[1])
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')]

if quantized:
# fake_quant will keep the tensors in float32 until the conversion in the session
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0"),
tf.quantization.fake_quant_with_min_max_args(in_data[1], min=-100, max=100, name="inq_1")]
out = math_op(inq_data[0], inq_data[1])
out = with_fused_activation_function(out, fused_activation_function)
out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
compare_tflite_with_tvm(data, ['inq_0:0', 'inq_1:0'], inq_data, [out], quantized=True)
else:
out = math_op(in_data[0], in_data[1])
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])

# Test with tensor and constant
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])

in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0')]

if quantized:
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0")]
inq_const = tf.quantization.fake_quant_with_min_max_args(data[1], min=-100, max=100, name="const_tensor")
# the 2nd tensor is treated as constant and directly added as part of the operation
out = math_op(inq_data, ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'))
out = with_fused_activation_function(out, fused_activation_function)
out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
compare_tflite_with_tvm(data[0], ['inq_0:0'], inq_data, [out], quantized=True)
else:
out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data[0], ['in_0:0'], in_data, [out])

#######################################################################
# Add
# ---

def _test_add(data, fused_activation_function=None):
def _test_add(data, fused_activation_function=None, quantized=False):
""" One iteration of add """
return _test_elemwise(math_ops.add, data, fused_activation_function)
return _test_elemwise(math_ops.add, data, fused_activation_function, quantized)

#######################################################################
# Subtract
Expand Down Expand Up @@ -627,14 +662,19 @@ def _test_greater(data):
def _test_forward_elemwise(testop):
""" Elewise"""
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3))])
np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3))])
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))])
testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3))])

def _test_forward_elemwise_quantized(testop):
testop([np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8)], quantized=True)

def test_all_elemwise():
_test_forward_elemwise(_test_add)
_test_forward_elemwise_quantized(_test_add)
_test_forward_elemwise(partial(_test_add, fused_activation_function="RELU"))
_test_forward_elemwise(partial(_test_add, fused_activation_function="RELU6"))
_test_forward_elemwise(_test_sub)
Expand Down

0 comments on commit e680611

Please sign in to comment.