From d9185644e0d8ea91fc23c4c98e439141284028d5 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Tue, 27 Oct 2020 23:30:08 +0530 Subject: [PATCH 1/2] TFLite failures resulted from TF latest version upgrade resolved --- docker/install/ubuntu_install_tflite.sh | 6 +- python/tvm/relay/frontend/tflite.py | 15 ++- tests/python/frontend/tflite/test_forward.py | 115 ++++++++++--------- 3 files changed, 77 insertions(+), 59 deletions(-) diff --git a/docker/install/ubuntu_install_tflite.sh b/docker/install/ubuntu_install_tflite.sh index 123ff520d725..2dfbb0681a80 100755 --- a/docker/install/ubuntu_install_tflite.sh +++ b/docker/install/ubuntu_install_tflite.sh @@ -33,14 +33,14 @@ pip3 install flatbuffers # Build the TFLite static library, necessary for building with TFLite ON. # The library is built at: # tensorflow/tensorflow/lite/tools/make/gen/*/lib/libtensorflow-lite.a. -git clone https://github.com/tensorflow/tensorflow --branch=r2.1 +git clone https://github.com/tensorflow/tensorflow --branch=r2.3 ./tensorflow/tensorflow/lite/tools/make/download_dependencies.sh ./tensorflow/tensorflow/lite/tools/make/build_lib.sh # Setup tflite from schema mkdir tflite cd tflite -wget -q https://raw.githubusercontent.com/tensorflow/tensorflow/r2.1/tensorflow/lite/schema/schema.fbs +wget -q https://raw.githubusercontent.com/tensorflow/tensorflow/r2.3/tensorflow/lite/schema/schema.fbs flatc --python schema.fbs cat <setup.py @@ -48,7 +48,7 @@ import setuptools setuptools.setup( name="tflite", - version="2.1.0", + version="2.3.1", author="google", author_email="google@google.com", description="TFLite", diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index f52c318c8e97..6da06ac4a20b 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2770,7 +2770,7 @@ def convert_transpose_conv(self, op): raise ImportError("The tflite package must be installed") input_tensors = self.get_input_tensors(op) - assert len(input_tensors) == 3, "input tensors length should be 3" + assert len(input_tensors) >= 3, "input tensors length should be >= 3" # Input (data) Tensor. NHWC layout input_tensor = input_tensors[2] @@ -2843,6 +2843,19 @@ def convert_transpose_conv(self, op): out_dtype=output_tensor_type_str, ) + # if we have bias + if len(input_tensors) == 4: + bias_tensor = input_tensors[3] + bias_tensor_type = bias_tensor.tensor.Type() + # bias tensor type should be INT32 (quantization) or FLOAT32 + assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32) + bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type) + bias_expr = self.exp_tab.new_const( + self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str + ) + channel_axis = 3 + out = _op.nn.bias_add(out, bias_expr, axis=channel_axis) + return out def convert_quantize(self, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index caa41806c8aa..58c1e3d69707 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -136,14 +136,20 @@ def vmobj_to_list(o): raise RuntimeError("Unknown object type: %s" % type(o)) -def _quantize_keras_model(keras_model, representative_data_gen): +def _quantize_keras_model( + keras_model, representative_data_gen, is_float_input=False, is_float_output=False +): """Utility function to quantize a Keras model using TFLite converter.""" converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] converter.representative_dataset = representative_data_gen converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 + # NOTE: If representative dataset is provided, and inference input type is not set, + # then converter will self add quant & dequant Op accordingly. + if not is_float_input: + converter.inference_input_type = tf.uint8 + if not is_float_output: + converter.inference_output_type = tf.uint8 return converter.convert() @@ -973,6 +979,7 @@ def _test_convolution( [out], quantized=quantized, input_range=input_range, + experimental_new_converter=True, ) else: # Quantized the inputs and feed them to the convolution @@ -1000,6 +1007,7 @@ def _test_convolution( [out], quantized=quantized, input_range=input_range, + experimental_new_converter=True, ) else: data_array = np.reshape(data_array, tensor_in_sizes).astype("float32") @@ -1078,18 +1086,18 @@ def test_forward_convolution(): ) # TFLite2 quantized convolution testing - if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"): - _test_tflite2_quantized_convolution( - [1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC" + if package_version.parse(tf.VERSION) >= package_version.parse("2.3.0"): + _test_convolution( + [1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC", quantized=True ) - _test_tflite2_quantized_convolution( - [1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], "VALID", "NHWC" + _test_convolution( + [1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], "VALID", "NHWC", quantized=True ) - _test_tflite2_quantized_convolution( - [1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC" + _test_convolution( + [1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC", quantized=True ) - _test_tflite2_quantized_convolution( - [1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC" + _test_convolution( + [1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC", quantized=True ) # Disable as tests are flaky - https://github.com/apache/incubator-tvm/issues/6064 @@ -2280,7 +2288,7 @@ def representative_data_gen(): for i in range(1): yield [data] - tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) + tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True) tflite_output = run_tflite_graph(tflite_model_quant, data) tvm_output = run_tvm_graph(tflite_model_quant, data, input_name) @@ -2307,7 +2315,7 @@ def representative_data_gen(): for i in range(1): yield [data] - tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) + tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True) tflite_output = run_tflite_graph(tflite_model_quant, data) tvm_output = run_tvm_graph(tflite_model_quant, data, input_name) @@ -2548,14 +2556,17 @@ def test_forward_padv2(): np.array([2], dtype=np.float32), ] ) - _test_padv2( - [ - np.arange(0, 256, dtype=np.uint8).reshape((1, 256)), - np.array([[1, 1], [2, 2]], dtype=np.int32), - np.array([2], dtype=np.uint8), - ], - quantized=True, - ) + # NOTE: In recent version, there is a bug in Tensorflow package for this scenario. + # Hence, it is disabled temporarily for latest TF version. + if package_version.parse(tf.VERSION) <= package_version.parse("2.1.0"): + _test_padv2( + [ + np.arange(0, 256, dtype=np.uint8).reshape((1, 256)), + np.array([[1, 1], [2, 2]], dtype=np.int32), + np.array([2], dtype=np.float32), + ], + quantized=True, + ) # Constant Values input can be scalar _test_padv2( @@ -2565,14 +2576,17 @@ def test_forward_padv2(): np.float32(2), ] ) - _test_padv2( - [ - np.arange(0, 256, dtype=np.uint8).reshape((1, 256)), - np.array([[1, 1], [2, 2]], dtype=np.int32), - np.uint8(10), - ], - quantized=True, - ) + # NOTE: In recent version, there is a bug in Tensorflow package for this scenario. + # Hence, it is disabled temporarily for latest TF version. + if package_version.parse(tf.VERSION) <= package_version.parse("2.1.0"): + _test_padv2( + [ + np.arange(0, 256, dtype=np.uint8).reshape((1, 256)), + np.array([[1, 1], [2, 2]], dtype=np.int32), + np.uint8(10), + ], + quantized=True, + ) ####################################################################### @@ -2870,37 +2884,28 @@ def test_forward_tanh(): def _test_relu(data, quantized=False): """ One iteration of ReLU """ - if quantized: - if package_version.parse(tf.VERSION) < package_version.parse("2.1.0"): - pytest.skip("Testcase requires tflite version >= 2.1.0") - data_in = tf.keras.layers.Input(shape=data.shape[1:]) - relu = tf.keras.layers.ReLU()(data_in) - keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu) - input_name = data_in.name.split(":")[0] - - # To create quantized values with dynamic range of activations, needs representative dataset - def representative_data_gen(): - for i in range(1): - yield [data] - - tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) - - tflite_output = run_tflite_graph(tflite_model_quant, data) - tvm_output = run_tvm_graph(tflite_model_quant, data, input_name) - tvm.testing.assert_allclose( - np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5 - ) - else: - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0") + + if quantized: + inq_data = tf.quantization.fake_quant_with_min_max_args( + in_data, min=-10, max=10, name="inq_0" + ) + input_range = {"inq_0": (-10, 10)} + out = nn_ops.relu(inq_data) + out = tf.quantization.fake_quant_with_min_max_args(out, min=0, max=6, name="out") + compare_tflite_with_tvm( + data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range + ) + else: out = nn_ops.relu(in_data) - compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out]) + compare_tflite_with_tvm(data, "in_0:0", [in_data], [out]) def test_forward_relu(): """ ReLU """ _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6))) - _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)), quantized=True) + _test_relu(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True) ####################################################################### From 04c9b16f7dcf638c2520fe1dc2d1199d39f23a0a Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Wed, 28 Oct 2020 09:10:20 +0530 Subject: [PATCH 2/2] [1] Review comments handled --- tests/python/frontend/tflite/test_forward.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 58c1e3d69707..3f860a3c6580 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2556,8 +2556,8 @@ def test_forward_padv2(): np.array([2], dtype=np.float32), ] ) - # NOTE: In recent version, there is a bug in Tensorflow package for this scenario. - # Hence, it is disabled temporarily for latest TF version. + # NOTE: In versions > 2.1.0, there is a bug in Tensorflow package for this scenario. + # Hence, it is disabled temporarily for TF version > 2.1.0 . if package_version.parse(tf.VERSION) <= package_version.parse("2.1.0"): _test_padv2( [ @@ -2576,8 +2576,8 @@ def test_forward_padv2(): np.float32(2), ] ) - # NOTE: In recent version, there is a bug in Tensorflow package for this scenario. - # Hence, it is disabled temporarily for latest TF version. + # NOTE: In versions > 2.1.0, there is a bug in Tensorflow package for this scenario. + # Hence, it is disabled temporarily for TF versions > 2.1.0. if package_version.parse(tf.VERSION) <= package_version.parse("2.1.0"): _test_padv2( [