From 8b4627197eb373320c1e3fa1fcfc8bb15e9292e7 Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Thu, 18 Jun 2020 23:11:13 +0000 Subject: [PATCH 1/9] [TFLite] TFLite 2.x parser quantization support. --- python/tvm/relay/frontend/tflite.py | 109 ++++++++++++--- tests/python/frontend/tflite/test_forward.py | 137 ++++++++++++++++++- 2 files changed, 223 insertions(+), 23 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 2fc82d74a08d..b131f866c11b 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -244,10 +244,46 @@ def get_tensors(self, tensors_idx_list): qnn_params = None tflite_qnn_params = tensor.Quantization() if tflite_qnn_params is not None: - scale = float(tflite_qnn_params.ScaleAsNumpy()) - zero_point = int(tflite_qnn_params.ZeroPointAsNumpy()) + # Params might be per-tensor or per-axis quantized. For per-tensor, scale and zero + # points are scalar. For per-axis, scale and zero points are tensors. But as per + # TFLite quantization spec, the restrictions on ops suggest that for per-axis, even + # if zero point is a tensor - all the zero points are identical. More infomration + # here - https://www.tensorflow.org/lite/performance/quantization_spec + + tflite_scale = tflite_qnn_params.ScaleAsNumpy() + tflite_zero_point = tflite_qnn_params.ZeroPointAsNumpy() + is_qnn_params_valid = True + + # Handle Per-axis and per-tensor cases + if isinstance(tflite_scale, np.ndarray): + assert isinstance(tflite_zero_point, np.ndarray) + + # Tensor - Per-axis quantization + if tflite_scale.shape != (1,) and tflite_zero_point.shape != (1,): + scale = tflite_scale + # Ensure that all zero points are identical + zero_point = tflite_zero_point + assert all(x == zero_point[0] for x in zero_point) + zero_point = int(zero_point[0]) + + # Scalar - Per-tensor quantization + elif tflite_scale.shape == (1,) and tflite_zero_point.shape == (1,): + scale = float(tflite_scale[0]) + zero_point = int(tflite_zero_point[0]) + + else: + raise NotImplementedError("Quantized type {} not supported" + .format(type(tflite_scale))) + elif tflite_scale == 0 and tflite_zero_point == 0: + # Handle corner case for ops like quantized reshape whose second operand (shape) + # has zero scale and zero zero point. This is not used. + is_qnn_params_valid = False + else: + raise NotImplementedError("Quantized type {} not supported" + .format(type(tflite_scale))) + # Check that the scale and zero points are valid. - if scale != 0 or zero_point != 0: + if is_qnn_params_valid: qnn_params = dict() qnn_params['scale'] = relay.const(scale, 'float32') qnn_params['zero_point'] = relay.const(zero_point, 'int32') @@ -263,21 +299,25 @@ def get_tensor_value(self, tensor_wrapper): except ImportError: raise ImportError("The tflite package must be installed") + data = tensor_wrapper.buffer.DataAsNumpy() + shape = tensor_wrapper.tensor.ShapeAsNumpy() + + # Set shape to 1 if the data is a scalar type + if data.shape == (1,) and isinstance(shape, int) and shape == 0: + shape = (1,) + + if tensor_wrapper.tensor.Type() == TensorType.INT8: + return np.frombuffer(data, dtype=np.int8).reshape(shape) if tensor_wrapper.tensor.Type() == TensorType.UINT8: - return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape( - tensor_wrapper.tensor.ShapeAsNumpy()) - if tensor_wrapper.tensor.Type() == TensorType.FLOAT32: - return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape( - tensor_wrapper.tensor.ShapeAsNumpy()) + return np.frombuffer(data, dtype=np.uint8).reshape(shape) if tensor_wrapper.tensor.Type() == TensorType.INT32: - return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape( - tensor_wrapper.tensor.ShapeAsNumpy()) + return np.frombuffer(data, dtype=np.int32).reshape(shape) if tensor_wrapper.tensor.Type() == TensorType.INT64: - return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape( - tensor_wrapper.tensor.ShapeAsNumpy()) + return np.frombuffer(data, dtype=np.int64).reshape(shape) + if tensor_wrapper.tensor.Type() == TensorType.FLOAT32: + return np.frombuffer(data, dtype=np.float32).reshape(shape) if tensor_wrapper.tensor.Type() == TensorType.BOOL: - return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape( - tensor_wrapper.tensor.ShapeAsNumpy()) + return np.frombuffer(data, dtype=np.bool).reshape(shape) raise NotImplementedError("Tensor type {} is currently not supported" .format(str(tensor_wrapper.tensor.Type()))) @@ -1606,7 +1646,7 @@ def convert_fully_connected(self, op): # weight tensor type should be UINT8 (quantization) or FLOAT32 weight_tensor_type = weight_tensor.tensor.Type() - assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32) + assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32) weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) if self.has_expr(weight_tensor.tensor_idx): @@ -1797,7 +1837,7 @@ def convert_conv(self, op, conv_type): # weight tensor type should be UINT8 (quantization) or FLOAT32 weight_tensor_type = weight_tensor.tensor.Type() - assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32) + assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32) weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) in_expr = self.get_expr(input_tensor_idx) @@ -1856,9 +1896,15 @@ def convert_conv(self, op, conv_type): if output_tensor.qnn_params: # Calculate the intermediate scale and zero point of the int32 output. data_scale = input_tensor.qnn_params['scale'] - weight_scale = weight_tensor.qnn_params['scale'] data_scale_val = get_scalar_from_constant(data_scale) - weight_scale_val = get_scalar_from_constant(weight_scale) + + weight_scale = weight_tensor.qnn_params['scale'] + # If weight scale is scalar, it is per-tensor quantization + if isinstance(weight_scale, float): + weight_scale_val = get_scalar_from_constant(weight_scale) + else: + weight_scale_val = get_vector_from_constant(weight_scale) + new_input_scale_val = data_scale_val * weight_scale_val new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') @@ -1869,7 +1915,8 @@ def convert_conv(self, op, conv_type): input_zero_point=new_input_zero_point, output_scale=output_tensor.qnn_params['scale'], output_zero_point=output_tensor.qnn_params['zero_point'], - out_dtype=output_tensor_type_str) + out_dtype=output_tensor_type_str, + axis=3) # Call activation function output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) @@ -2594,17 +2641,27 @@ def convert_quantize(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] + input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type()) in_expr = self.get_expr(input_tensor.tensor_idx) output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) # The output must be quantized assert output_tensor.qnn_params - # Quantize the input - out = self.quantize(in_expr, output_tensor) + # TFLite Quantize op can also act as Requantize op + if input_tensor_type_str == "float32": + out = self.quantize(in_expr, output_tensor) + else: + out = _qnn.op.requantize(in_expr, + input_scale=input_tensor.qnn_params['scale'], + input_zero_point=input_tensor.qnn_params['zero_point'], + output_scale=output_tensor.qnn_params['scale'], + output_zero_point=output_tensor.qnn_params['zero_point'], + out_dtype=output_tensor_type_str) return out def convert_dequantize(self, op): @@ -2734,7 +2791,6 @@ def get_tensor_expr(self, tensor): else: type_str = self.get_tensor_type_str(tensor.tensor.Type()) expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str) - return expr @@ -2747,6 +2803,15 @@ def get_scalar_from_constant(expr): "value must be float32/int32" return np.asscalar(value) +def get_vector_from_constant(expr): + """ Returns scalar value from Relay constant scalar. """ + assert isinstance(expr, _expr.Constant) + value = expr.data.asnumpy() + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ + "value must be float32/int32" + return value + + def build_str_map(obj): """Build string map of TFLite enum int value diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 166eb2740edb..249e5f0bfa47 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import variables +import tensorflow_hub as hub try: from tensorflow import lite as interpreter_wrapper except ImportError: @@ -73,6 +74,28 @@ def get_real_image(im_height, im_width): data = np.reshape(x, (1, im_height, im_width, 3)) return data + +def pre_processed_image(height, width): + repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' + img_name = 'elephant-299.jpg' + image_url = os.path.join(repo_base, img_name) + img_path = download_testdata(image_url, img_name, module='data') + image = tf.io.read_file(img_path) + image = tf.image.decode_jpeg(image, channels=3) + with tf.name_scope('eval_image'): + if image.dtype != tf.float32: + image = tf.image.convert_image_dtype(image, dtype=tf.float32) + image = tf.image.central_crop(image, central_fraction=0.875) + # Resize the image to the specified height and width. + image = tf.expand_dims(image, 0) + image = tf.image.resize(image, [height, width], + align_corners=False) + image = tf.image.resize(image, [height, width]) + image = tf.squeeze(image, [0]) + image = tf.expand_dims(image, axis=0) + return image + + def get_real_image_object_detection(im_height, im_width): repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/' img_name = 'street_small.jpg' @@ -1707,7 +1730,6 @@ def representative_data_gen(): # Convert the model to TensorFlow Lite format tflite_model_quant = converter.convert() - tflite_output = run_tflite_graph(tflite_model_quant, data) tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), @@ -2471,6 +2493,112 @@ def test_forward_qnn_mobilenet_v3_net(): tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + +def _quantize_tf_hub_keras_model(url, height, width): + keras_model = tf.keras.Sequential([hub.KerasLayer(url, output_shape=[1001])]) + data = pre_processed_image(height, width) + + # Set the input shapes of the keras model + keras_model._set_inputs(data) + + # Get the converter + converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) + + # To create quantized values with dynamic range of activations, needs representative dataset + def representative_data_gen(): + for i in range(1): + yield [data] + + converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] + converter.representative_dataset = representative_data_gen + return converter.convert() + + +def test_forward_tflite2_qnn_resnet50(): + """Test the Quantized TFLite version 2.1.0 Resnet50 model.""" + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): + # Quantize the model + url = "https://tfhub.dev/tensorflow/resnet_50/classification/1" + tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224) + data = pre_processed_image(224, 224) + + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1') + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + + +def test_forward_tflite2_qnn_inception_v1(): + """Test the Quantized TFLite version 2.1.0 Inception V1 model.""" + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): + # Quantize the model + url = "https://tfhub.dev/google/imagenet/inception_v1/classification/4" + tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224) + data = pre_processed_image(224, 224) + + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1') + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + + +def test_forward_tflite2_qnn_inception_v3(): + """Test the Quantized TFLite version 2.1.0 Inception V3 model.""" + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): + # Quantize the model + url = "https://tfhub.dev/google/imagenet/inception_v3/classification/4" + tflite_model_buf = _quantize_tf_hub_keras_model(url, 299, 299) + data = pre_processed_image(299, 299) + + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1') + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + + +def test_forward_tflite2_qnn_mobilenet_v1(): + """Test the Quantized TFLite version 2.1.0 Mobilenet V1 model.""" + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): + # Quantize the model + url = "https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/4" + tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224) + data = pre_processed_image(224, 224) + + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1') + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + + +def test_forward_tflite2_qnn_mobilenet_v2(): + """Test the Quantized TFLite version 2.1.0 Mobilenet V2 model.""" + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): + # Quantize the model + url = "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/4" + tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224) + data = pre_processed_image(224, 224) + + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1') + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + + ####################################################################### # Quantized SSD Mobilenet # ----------------------- @@ -2689,3 +2817,10 @@ def test_forward_mediapipe_hand_landmark(): #with Tflite 1.15.2 test_forward_qnn_mobilenet_v3_net() test_forward_qnn_coco_ssd_mobilenet_v1() + + # TFLite 2.1.0 quantized tests + test_forward_tflite2_qnn_resnet50() + test_forward_tflite2_qnn_inception_v1() + test_forward_tflite2_qnn_inception_v3() + test_forward_tflite2_qnn_mobilenet_v1() + test_forward_tflite2_qnn_mobilenet_v2() From 7f52fc8d9d520f36f99e8426c807bbd365bb3fde Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Thu, 25 Jun 2020 07:18:27 +0000 Subject: [PATCH 2/9] Address comments. Fix a bug for depthwise conv --- python/tvm/relay/frontend/tflite.py | 44 +++++++++++--------- src/relay/qnn/op/convolution.cc | 16 +++++-- tests/python/frontend/tflite/test_forward.py | 5 +-- 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index b131f866c11b..ed5a4c3c1fc4 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -244,11 +244,13 @@ def get_tensors(self, tensors_idx_list): qnn_params = None tflite_qnn_params = tensor.Quantization() if tflite_qnn_params is not None: - # Params might be per-tensor or per-axis quantized. For per-tensor, scale and zero - # points are scalar. For per-axis, scale and zero points are tensors. But as per - # TFLite quantization spec, the restrictions on ops suggest that for per-axis, even - # if zero point is a tensor - all the zero points are identical. More infomration - # here - https://www.tensorflow.org/lite/performance/quantization_spec + # TFLite supports both per-tensor and per-axis (aka channel) quantization. For + # per-tensor quantization, scale and zero points are scalar values. For per-axis + # quantization, scale and zero points for the weights are tensors (activations are + # per-tensor quantized). However, the TFLite quantization spec puts restrictions on + # zero points for per-axis quantization. Specifically, the zero point is a tensor + # but all values are 0. More information can be found here - + # https://www.tensorflow.org/lite/performance/quantization_spec tflite_scale = tflite_qnn_params.ScaleAsNumpy() tflite_zero_point = tflite_qnn_params.ZeroPointAsNumpy() @@ -259,15 +261,18 @@ def get_tensors(self, tensors_idx_list): assert isinstance(tflite_zero_point, np.ndarray) # Tensor - Per-axis quantization - if tflite_scale.shape != (1,) and tflite_zero_point.shape != (1,): + if tflite_scale.size != 1 and tflite_zero_point.size != 1: scale = tflite_scale - # Ensure that all zero points are identical + # Ensure that all zero points are zeros zero_point = tflite_zero_point - assert all(x == zero_point[0] for x in zero_point) + if not all(x == 0 for x in zero_point): + raise tvm.error.OpAttributeInvalid(\ + "TFLite per-axis quantization restricts all zero points to be" + + " 0, but a non-zero value is observed") zero_point = int(zero_point[0]) # Scalar - Per-tensor quantization - elif tflite_scale.shape == (1,) and tflite_zero_point.shape == (1,): + elif tflite_scale.size == 1 and tflite_zero_point.size == 1: scale = float(tflite_scale[0]) zero_point = int(tflite_zero_point[0]) @@ -299,11 +304,15 @@ def get_tensor_value(self, tensor_wrapper): except ImportError: raise ImportError("The tflite package must be installed") + # Read the data from the buffer. Also extract the shape. + # The shape is used later to reshape the data. data = tensor_wrapper.buffer.DataAsNumpy() shape = tensor_wrapper.tensor.ShapeAsNumpy() - # Set shape to 1 if the data is a scalar type - if data.shape == (1,) and isinstance(shape, int) and shape == 0: + # When TFLite buffer is of size 1 (scalar), then TFLite tensor shape is set to 0. + # Therefore, we set the shape to 1 for numpy reshape to work. + Set shape to 1 if the data is a scalar type + if data.size == 1 and isinstance(shape, int) and shape == 0: shape = (1,) if tensor_wrapper.tensor.Type() == TensorType.INT8: @@ -1644,7 +1653,7 @@ def convert_fully_connected(self, op): fully_connected_options.Init(op_options.Bytes, op_options.Pos) fused_activation_fn = fully_connected_options.FusedActivationFunction() - # weight tensor type should be UINT8 (quantization) or FLOAT32 + # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32 weight_tensor_type = weight_tensor.tensor.Type() assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32) weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) @@ -1835,7 +1844,7 @@ def convert_conv(self, op, conv_type): params['channels'] = int(output_channels) params['kernel_layout'] = 'HWIO' - # weight tensor type should be UINT8 (quantization) or FLOAT32 + # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32 weight_tensor_type = weight_tensor.tensor.Type() assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32) weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) @@ -1903,7 +1912,7 @@ def convert_conv(self, op, conv_type): if isinstance(weight_scale, float): weight_scale_val = get_scalar_from_constant(weight_scale) else: - weight_scale_val = get_vector_from_constant(weight_scale) + weight_scale_val = get_tensor_from_constant(weight_scale) new_input_scale_val = data_scale_val * weight_scale_val new_input_scale = relay.const(new_input_scale_val, 'float32') @@ -1929,7 +1938,6 @@ def convert_conv(self, op, conv_type): dtype=output_tensor_type_str) else: out = self.convert_fused_activation_function(out, fused_activation_fn) - return out def convert_split(self, op): @@ -2803,16 +2811,14 @@ def get_scalar_from_constant(expr): "value must be float32/int32" return np.asscalar(value) -def get_vector_from_constant(expr): - """ Returns scalar value from Relay constant scalar. """ +def get_tensor_from_constant(expr): + """ Returns tensor of values from Relay constant node. """ assert isinstance(expr, _expr.Constant) value = expr.data.asnumpy() assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ "value must be float32/int32" return value - - def build_str_map(obj): """Build string map of TFLite enum int value diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 9412ab4393c5..5d2e360e0951 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -61,9 +61,19 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale // Kernel scale can be a vector of length output_channels or a scalar. - size_t axis = param->kernel_layout.find('O'); - CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined"; - AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale + if (param->groups == 1) { + size_t axis = param->kernel_layout.find('O'); + CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined"; + AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale + } else { + // Here, total number of output channels depend on depth multiplier. + size_t o_axis = param->kernel_layout.find('O'); + size_t i_axis = param->kernel_layout.find('I'); + CHECK(o_axis != std::string::npos || i_axis != std::string::npos) + << "Kernel layout attribute is not defined"; + AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis], + reporter); // kernel scale + } // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // Conv2D infer type function. diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 249e5f0bfa47..56a06803a9f2 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -87,11 +87,8 @@ def pre_processed_image(height, width): image = tf.image.convert_image_dtype(image, dtype=tf.float32) image = tf.image.central_crop(image, central_fraction=0.875) # Resize the image to the specified height and width. - image = tf.expand_dims(image, 0) image = tf.image.resize(image, [height, width], align_corners=False) - image = tf.image.resize(image, [height, width]) - image = tf.squeeze(image, [0]) image = tf.expand_dims(image, axis=0) return image @@ -2493,8 +2490,8 @@ def test_forward_qnn_mobilenet_v3_net(): tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) - def _quantize_tf_hub_keras_model(url, height, width): + """Utility function to quantize a Keras model using TFLite converter.""" keras_model = tf.keras.Sequential([hub.KerasLayer(url, output_shape=[1001])]) data = pre_processed_image(height, width) From c7a198d7ce34842caf3a49c5d10ff4dcd63cb453 Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Thu, 25 Jun 2020 20:33:18 +0000 Subject: [PATCH 3/9] Added tests for relu, conv, quantize. Address comments. --- python/tvm/relay/frontend/tflite.py | 68 ++++-- tests/python/frontend/tflite/test_forward.py | 217 ++++++++++++++++--- 2 files changed, 232 insertions(+), 53 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index ed5a4c3c1fc4..58adf8d9f15b 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -265,7 +265,7 @@ def get_tensors(self, tensors_idx_list): scale = tflite_scale # Ensure that all zero points are zeros zero_point = tflite_zero_point - if not all(x == 0 for x in zero_point): + if not np.all(zero_point == 0): raise tvm.error.OpAttributeInvalid(\ "TFLite per-axis quantization restricts all zero points to be" + " 0, but a non-zero value is observed") @@ -277,8 +277,9 @@ def get_tensors(self, tensors_idx_list): zero_point = int(tflite_zero_point[0]) else: - raise NotImplementedError("Quantized type {} not supported" - .format(type(tflite_scale))) + raise NotImplementedError(\ + "Quantized type {} (scale) and {} (zero point) not supported" + .format(type(tflite_scale), type(tflite_zero_point))) elif tflite_scale == 0 and tflite_zero_point == 0: # Handle corner case for ops like quantized reshape whose second operand (shape) # has zero scale and zero zero point. This is not used. @@ -310,8 +311,8 @@ def get_tensor_value(self, tensor_wrapper): shape = tensor_wrapper.tensor.ShapeAsNumpy() # When TFLite buffer is of size 1 (scalar), then TFLite tensor shape is set to 0. - # Therefore, we set the shape to 1 for numpy reshape to work. - Set shape to 1 if the data is a scalar type + # Therefore, we set the shape to 1 for numpy reshape to work. Set shape to 1 if the data is + # a scalar type if data.size == 1 and isinstance(shape, int) and shape == 0: shape = (1,) @@ -700,12 +701,43 @@ def convert_shape(self, op): def convert_relu(self, op): """Convert TFLite ReLU""" + try: + from tflite.ActivationFunctionType import ActivationFunctionType + except ImportError: + raise ImportError("The tflite package must be installed") + input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" - input_tensor = input_tensors[0] in_expr = self.get_expr(input_tensor.tensor_idx) - out = _op.nn.relu(in_expr) + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + if input_tensor.qnn_params: + # Quantize a float value to an quantized integer value + scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point']) + + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=in_expr, + fused_activation_fn=ActivationFunctionType.RELU, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: + out = _op.clip(in_expr, a_min=0, a_max=6) + + if output_tensor.qnn_params: + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = _qnn.op.requantize(out, + input_scale=input_tensor.qnn_params['scale'], + input_zero_point=input_tensor.qnn_params['zero_point'], + output_scale=output_tensor.qnn_params['scale'], + output_zero_point=output_tensor.qnn_params['zero_point'], + out_dtype=output_tensor_type_str) return out @@ -741,6 +773,11 @@ def _hard_swish(data): def convert_relu6(self, op): """Convert TFLite ReLU6""" + try: + from tflite.ActivationFunctionType import ActivationFunctionType + except ImportError: + raise ImportError("The tflite package must be installed") + input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -754,17 +791,14 @@ def convert_relu6(self, op): # Quantize a float value to an quantized integer value scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale']) zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point']) - quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val) - # Get min/max of the input dtype. This will be used to ensure that - # clip a_min/a_max are not beyond the dtype range. - input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type()) - qmin = float(tvm.tir.op.min_value(input_tensor_type_str).value) - qmax = float(tvm.tir.op.max_value(input_tensor_type_str).value) - - out = _op.clip(in_expr, - a_min=max(qmin, quantize(0)), - a_max=min(qmax, quantize(6.0))) + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=in_expr, + fused_activation_fn=ActivationFunctionType.RELU6, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) else: out = _op.clip(in_expr, a_min=0, a_max=6) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 56a06803a9f2..a07696624b18 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -737,6 +737,86 @@ def test_forward_l2_pool2d(): # Convolution # ----------- + +def _test_tflite2_quantized_convolution(input_shape, kernel_shape, + dilations, strides, padding, data_format): + """ One iteration of TFLite2 quantized convolution with given shapes and attributes """ + data_format = "channels_last" if "NHWC" else "channels_first" + data = np.random.uniform(0, 1, input_shape).astype('float32') + kernel = np.random.uniform(0, 1, kernel_shape).astype('float32') + + data_in = tf.keras.layers.Input(shape=data.shape[1:]) + conv = tf.keras.layers.Conv2D(filters=kernel_shape[3], + kernel_size=(kernel_shape[0], kernel_shape[1]), + strides=strides, + padding=padding, + data_format=data_format, + activation='relu', + use_bias=False)(data_in) + keras_model = tf.keras.models.Model(data_in, conv) + keras_model.layers[1].set_weights([kernel]) + + converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) + + # To create quantized values with dynamic range of activations, needs representative dataset + def representative_data_gen(): + for i in range(1): + yield [data] + + converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] + converter.representative_dataset = representative_data_gen + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 + + # Convert the model to TensorFlow Lite format + tflite_model_quant = converter.convert() + + tflite_output = run_tflite_graph(tflite_model_quant, data) + tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0","")) + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-2, atol=1e-2) + + +def _test_tflite2_quantized_depthwise_convolution(input_shape, kernel_shape, + dilations, strides, padding, data_format, depth_multiplier): + """One iteration of TFLite2 quantized depthwise convolution with given shapes and attributes""" + data_format = "channels_last" if "NHWC" else "channels_first" + data = np.random.uniform(0, 1, input_shape).astype('float32') + kernel = np.random.uniform(0, 1, kernel_shape).astype('float32') + + data_in = tf.keras.layers.Input(shape=data.shape[1:]) + conv = tf.keras.layers.DepthwiseConv2D(kernel_size=(kernel_shape[0], kernel_shape[1]), + strides=strides, + padding=padding, + data_format=data_format, + activation='relu', + use_bias=False, + depth_multiplier=depth_multiplier)(data_in) + keras_model = tf.keras.models.Model(data_in, conv) + keras_model.layers[1].set_weights([kernel]) + + converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) + + # To create quantized values with dynamic range of activations, needs representative dataset + def representative_data_gen(): + for i in range(1): + yield [data] + + converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] + converter.representative_dataset = representative_data_gen + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 + + # Convert the model to TensorFlow Lite format + tflite_model_quant = converter.convert() + tflite_output = run_tflite_graph(tflite_model_quant, data) + tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0","")) + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-2, atol=1e-2) + + def _test_convolution(tensor_in_sizes, filter_in_sizes, dilations, strides, padding, data_format, is_depthwise=False, quantized=False): @@ -777,24 +857,38 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, data_format=data_format) if quantized: - # For now only quantized conv2d is supported - assert not is_depthwise - - # Quantized the inputs and feed them to the convolution - inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data') - inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter') - out = nn_ops.conv2d(inq_data, - inq_filter, - strides=strides, - padding=padding, - data_format=data_format) - out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out") - - # Set the input quantization range - input_range = {'in_data': (-100, 100)} if quantized else None - - # Compare - compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range) + if is_depthwise: + # Quantized the inputs and feed them to the convolution + inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data') + inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter') + out = nn_ops.depthwise_conv2d_native(inq_data, + inq_filter, + strides=strides, + padding=padding, + data_format=data_format) + out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out") + + # Set the input quantization range + input_range = {'in_data': (-100, 100)} if quantized else None + + # Compare + compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range) + else: + # Quantized the inputs and feed them to the convolution + inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data') + inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter') + out = nn_ops.conv2d(inq_data, + inq_filter, + strides=strides, + padding=padding, + data_format=data_format) + out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out") + + # Set the input quantization range + input_range = {'in_data': (-100, 100)} if quantized else None + + # Compare + compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range) else: data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out]) @@ -807,14 +901,30 @@ def test_forward_convolution(): _test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC', quantized=quantized) _test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC', quantized=quantized) - # depthwise convolution - _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True) - _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True) - _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True) - _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True) - _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True) - # dephtwise convolution with single input channel - _test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', 'NHWC', True) + # depthwise convolution + _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized) + _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized) + _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized) + _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized) + _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized) + # dephtwise convolution with single input channel + _test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized) + + # TFLite2 quantized convolution testing + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): + _test_tflite2_quantized_convolution([1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC') + _test_tflite2_quantized_convolution([1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC') + _test_tflite2_quantized_convolution([1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC') + _test_tflite2_quantized_convolution([1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC') + + # depthwise convolution + _test_tflite2_quantized_depthwise_convolution([1, 8, 8, 128], [1, 1, 128, 1], [1, 1], [1, 1], + 'SAME', 'NHWC', 1) + _test_tflite2_quantized_depthwise_convolution([1, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], + 'VALID', 'NHWC', 1) + _test_tflite2_quantized_depthwise_convolution([1, 24, 24, 3], [7, 7, 3, 8], [1, 1], [2, 2], + 'SAME', 'NHWC', 8) + ####################################################################### @@ -1706,10 +1816,14 @@ def test_forward_squeeze(): def _test_quantize_dequantize(data): """ One iteration of quantize and dequantize """ - # Define a dummy model + # Keras model to force TFLite converter to insert 2 TFLite quantize ops. + # First TFLite quantize op converts float32 tensor to int8 tensor - Qnn quantize. + # Second TLite quantize op converts int8 tensor to int8 tensor - Qnn requantize. data_in = tf.keras.layers.Input(shape=data.shape[1:]) - act_func = tf.keras.layers.Activation('linear') - keras_model = tf.keras.models.Model(data_in, act_func(data_in)) + relu = tf.keras.layers.ReLU()(data_in) + add = tf.keras.layers.Add()([data_in, relu]) + concat = tf.keras.layers.Concatenate(axis=0)([relu, add]) + keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat) # Load the model converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) @@ -1727,16 +1841,17 @@ def representative_data_gen(): # Convert the model to TensorFlow Lite format tflite_model_quant = converter.convert() + tflite_output = run_tflite_graph(tflite_model_quant, data) tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), - rtol=1e-5, atol=1e-5) + rtol=1e-5, atol=1e-2) def test_forward_quantize_dequantize(): """ Quantize Dequantize """ data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32") - if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'): + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): _test_quantize_dequantize(data) @@ -1964,16 +2079,46 @@ def test_forward_tanh(): # ReLu # ---- -def _test_relu(data): +def _test_relu(data, quantized=False): """ One iteration of ReLU """ - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - out = nn_ops.relu(in_data) - compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + + if quantized: + if package_version.parse(tf.VERSION) < package_version.parse('2.1.0'): + pytest.skip("Testcase requires tflite version >= 2.1.0") + data_in = tf.keras.layers.Input(shape=data.shape[1:]) + relu = tf.keras.layers.ReLU()(data_in) + keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu) + + # Load the model + converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) + + # To create quantized values with dynamic range of activations, needs representative dataset + def representative_data_gen(): + for i in range(100): + yield [data] + + converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] + converter.representative_dataset = representative_data_gen + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 + + # Convert the model to TensorFlow Lite format + tflite_model_quant = converter.convert() + tflite_output = run_tflite_graph(tflite_model_quant, data) + tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1') + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-5, atol=1e-5) + else: + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = nn_ops.relu(in_data) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) def test_forward_relu(): """ ReLU """ _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)), quantized=True) ####################################################################### # ReLU6 From 914ef0ad5574a1011a7dc506483805ebde90b5e2 Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Sun, 28 Jun 2020 18:15:44 +0000 Subject: [PATCH 4/9] Using web-data. Minor refactoring. --- tests/python/frontend/tflite/test_forward.py | 142 +++++-------------- 1 file changed, 36 insertions(+), 106 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index a07696624b18..e618ce7b2e0e 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -129,6 +129,19 @@ def vmobj_to_list(o): else: raise RuntimeError("Unknown object type: %s" % type(o)) + +def _quantize_keras_model(keras_model, representative_data_gen): + """Utility function to quantize a Keras model using TFLite converter.""" + converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) + converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] + converter.representative_dataset = representative_data_gen + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 + converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) + return converter.convert() + + def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', out_names=None, mode='graph_runtime'): """ Generic function to compile on relay and execute on tvm """ @@ -756,21 +769,12 @@ def _test_tflite2_quantized_convolution(input_shape, kernel_shape, keras_model = tf.keras.models.Model(data_in, conv) keras_model.layers[1].set_weights([kernel]) - converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) - # To create quantized values with dynamic range of activations, needs representative dataset def representative_data_gen(): for i in range(1): yield [data] - converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] - converter.representative_dataset = representative_data_gen - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 - - # Convert the model to TensorFlow Lite format - tflite_model_quant = converter.convert() + tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) tflite_output = run_tflite_graph(tflite_model_quant, data) tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0","")) @@ -796,21 +800,14 @@ def _test_tflite2_quantized_depthwise_convolution(input_shape, kernel_shape, keras_model = tf.keras.models.Model(data_in, conv) keras_model.layers[1].set_weights([kernel]) - converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) # To create quantized values with dynamic range of activations, needs representative dataset def representative_data_gen(): for i in range(1): yield [data] - converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] - converter.representative_dataset = representative_data_gen - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 + tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) - # Convert the model to TensorFlow Lite format - tflite_model_quant = converter.convert() tflite_output = run_tflite_graph(tflite_model_quant, data) tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0","")) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), @@ -1825,22 +1822,12 @@ def _test_quantize_dequantize(data): concat = tf.keras.layers.Concatenate(axis=0)([relu, add]) keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat) - # Load the model - converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) - # To create quantized values with dynamic range of activations, needs representative dataset def representative_data_gen(): - for i in range(100): + for i in range(1): yield [data] - converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] - converter.representative_dataset = representative_data_gen - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 - - # Convert the model to TensorFlow Lite format - tflite_model_quant = converter.convert() + tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) tflite_output = run_tflite_graph(tflite_model_quant, data) tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1') @@ -2089,22 +2076,13 @@ def _test_relu(data, quantized=False): relu = tf.keras.layers.ReLU()(data_in) keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu) - # Load the model - converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) - # To create quantized values with dynamic range of activations, needs representative dataset def representative_data_gen(): - for i in range(100): + for i in range(1): yield [data] - converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] - converter.representative_dataset = representative_data_gen - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 + tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) - # Convert the model to TensorFlow Lite format - tflite_model_quant = converter.convert() tflite_output = run_tflite_graph(tflite_model_quant, data) tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), @@ -2635,33 +2613,15 @@ def test_forward_qnn_mobilenet_v3_net(): tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) -def _quantize_tf_hub_keras_model(url, height, width): - """Utility function to quantize a Keras model using TFLite converter.""" - keras_model = tf.keras.Sequential([hub.KerasLayer(url, output_shape=[1001])]) - data = pre_processed_image(height, width) - - # Set the input shapes of the keras model - keras_model._set_inputs(data) - - # Get the converter - converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) - - # To create quantized values with dynamic range of activations, needs representative dataset - def representative_data_gen(): - for i in range(1): - yield [data] - - converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] - converter.representative_dataset = representative_data_gen - return converter.convert() - - def test_forward_tflite2_qnn_resnet50(): """Test the Quantized TFLite version 2.1.0 Resnet50 model.""" if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): - # Quantize the model - url = "https://tfhub.dev/tensorflow/resnet_50/classification/1" - tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224) + tflite_model_file = download_testdata( + "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/resnet_50_quantized.tflite", + "resnet_50_quantized.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + data = pre_processed_image(224, 224) tflite_output = run_tflite_graph(tflite_model_buf, data) @@ -2676,43 +2636,12 @@ def test_forward_tflite2_qnn_resnet50(): def test_forward_tflite2_qnn_inception_v1(): """Test the Quantized TFLite version 2.1.0 Inception V1 model.""" if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): - # Quantize the model - url = "https://tfhub.dev/google/imagenet/inception_v1/classification/4" - tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224) - data = pre_processed_image(224, 224) + tflite_model_file = download_testdata( + "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/inception_v1_quantized.tflite", + "inception_v1_quantized.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() - tflite_output = run_tflite_graph(tflite_model_buf, data) - tflite_predictions = np.squeeze(tflite_output) - tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] - tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1') - tvm_predictions = np.squeeze(tvm_output) - tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] - tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) - - -def test_forward_tflite2_qnn_inception_v3(): - """Test the Quantized TFLite version 2.1.0 Inception V3 model.""" - if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): - # Quantize the model - url = "https://tfhub.dev/google/imagenet/inception_v3/classification/4" - tflite_model_buf = _quantize_tf_hub_keras_model(url, 299, 299) - data = pre_processed_image(299, 299) - - tflite_output = run_tflite_graph(tflite_model_buf, data) - tflite_predictions = np.squeeze(tflite_output) - tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] - tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1') - tvm_predictions = np.squeeze(tvm_output) - tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] - tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) - - -def test_forward_tflite2_qnn_mobilenet_v1(): - """Test the Quantized TFLite version 2.1.0 Mobilenet V1 model.""" - if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): - # Quantize the model - url = "https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/4" - tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224) data = pre_processed_image(224, 224) tflite_output = run_tflite_graph(tflite_model_buf, data) @@ -2727,9 +2656,12 @@ def test_forward_tflite2_qnn_mobilenet_v1(): def test_forward_tflite2_qnn_mobilenet_v2(): """Test the Quantized TFLite version 2.1.0 Mobilenet V2 model.""" if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): - # Quantize the model - url = "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/4" - tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224) + tflite_model_file = download_testdata( + "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/mobilenet_v2_quantized.tflite", + "mobilenet_v2_quantized.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + data = pre_processed_image(224, 224) tflite_output = run_tflite_graph(tflite_model_buf, data) @@ -2963,6 +2895,4 @@ def test_forward_mediapipe_hand_landmark(): # TFLite 2.1.0 quantized tests test_forward_tflite2_qnn_resnet50() test_forward_tflite2_qnn_inception_v1() - test_forward_tflite2_qnn_inception_v3() - test_forward_tflite2_qnn_mobilenet_v1() test_forward_tflite2_qnn_mobilenet_v2() From 3534b0108770804b66c0c09d4cfb676628cac2d4 Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Mon, 29 Jun 2020 00:39:41 +0000 Subject: [PATCH 5/9] Removing TF hub package --- tests/python/frontend/tflite/test_forward.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index e618ce7b2e0e..0ef8cbab68e7 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -39,7 +39,6 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import variables -import tensorflow_hub as hub try: from tensorflow import lite as interpreter_wrapper except ImportError: From e0cc5c73e01a86e6ef21b65beff2e80e087186c6 Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Mon, 29 Jun 2020 03:13:32 +0000 Subject: [PATCH 6/9] Trigger CI. From cef3a853409642f82ca2df9d2573e314f106ae60 Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Mon, 29 Jun 2020 06:11:21 +0000 Subject: [PATCH 7/9] Handle TFLite input layer naming. --- tests/python/frontend/tflite/test_forward.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 0ef8cbab68e7..bea0a1b6c006 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1820,6 +1820,7 @@ def _test_quantize_dequantize(data): add = tf.keras.layers.Add()([data_in, relu]) concat = tf.keras.layers.Concatenate(axis=0)([relu, add]) keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat) + input_name = data_in.name.split(":")[0] # To create quantized values with dynamic range of activations, needs representative dataset def representative_data_gen(): @@ -1829,7 +1830,7 @@ def representative_data_gen(): tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) tflite_output = run_tflite_graph(tflite_model_quant, data) - tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1') + tvm_output = run_tvm_graph(tflite_model_quant, data, input_name) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2) @@ -2074,6 +2075,7 @@ def _test_relu(data, quantized=False): data_in = tf.keras.layers.Input(shape=data.shape[1:]) relu = tf.keras.layers.ReLU()(data_in) keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu) + input_name = data_in.name.split(":")[0] # To create quantized values with dynamic range of activations, needs representative dataset def representative_data_gen(): @@ -2083,7 +2085,7 @@ def representative_data_gen(): tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) tflite_output = run_tflite_graph(tflite_model_quant, data) - tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1') + tvm_output = run_tvm_graph(tflite_model_quant, data, input_name) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) else: From 0518b31c260653dfe49409eb86a1655b610e11e2 Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Tue, 30 Jun 2020 16:54:27 +0000 Subject: [PATCH 8/9] Addressing reviews. --- python/tvm/relay/frontend/tflite.py | 54 ++++++++++---------- tests/python/frontend/tflite/test_forward.py | 5 +- 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 58adf8d9f15b..36221b7467aa 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -296,40 +296,40 @@ def get_tensors(self, tensors_idx_list): return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params)) return return_list - def get_tensor_value(self, tensor_wrapper): - """Get tensor buffer value from given tensor wrapper""" + + def get_tensor_type_as_numpy(self, tensor_wrapper): + """Returns np.dtype out of TensorType""" assert isinstance(tensor_wrapper, TensorWrapper) try: from tflite.TensorType import TensorType + return {TensorType.UINT8: np.uint8, + TensorType.INT8: np.int8, + TensorType.FLOAT32: np.float32, + TensorType.INT32: np.int32, + TensorType.INT64: np.int64, + TensorType.BOOL: np.bool_}[tensor_wrapper.tensor.Type()] except ImportError: raise ImportError("The tflite package must be installed") + except KeyError: + raise NotImplementedError("Tensor type '{}' currently not supported" + .format(tensor_wrapper.tensor.Type())) - # Read the data from the buffer. Also extract the shape. - # The shape is used later to reshape the data. + + def get_tensor_value(self, tensor_wrapper): + """Get tensor buffer value from given tensor wrapper""" + assert isinstance(tensor_wrapper, TensorWrapper) + + dtype = self.get_tensor_type_as_numpy(tensor_wrapper) data = tensor_wrapper.buffer.DataAsNumpy() - shape = tensor_wrapper.tensor.ShapeAsNumpy() - - # When TFLite buffer is of size 1 (scalar), then TFLite tensor shape is set to 0. - # Therefore, we set the shape to 1 for numpy reshape to work. Set shape to 1 if the data is - # a scalar type - if data.size == 1 and isinstance(shape, int) and shape == 0: - shape = (1,) - - if tensor_wrapper.tensor.Type() == TensorType.INT8: - return np.frombuffer(data, dtype=np.int8).reshape(shape) - if tensor_wrapper.tensor.Type() == TensorType.UINT8: - return np.frombuffer(data, dtype=np.uint8).reshape(shape) - if tensor_wrapper.tensor.Type() == TensorType.INT32: - return np.frombuffer(data, dtype=np.int32).reshape(shape) - if tensor_wrapper.tensor.Type() == TensorType.INT64: - return np.frombuffer(data, dtype=np.int64).reshape(shape) - if tensor_wrapper.tensor.Type() == TensorType.FLOAT32: - return np.frombuffer(data, dtype=np.float32).reshape(shape) - if tensor_wrapper.tensor.Type() == TensorType.BOOL: - return np.frombuffer(data, dtype=np.bool).reshape(shape) - raise NotImplementedError("Tensor type {} is currently not supported" - .format(str(tensor_wrapper.tensor.Type()))) + + if tensor_wrapper.tensor.ShapeLength() != 0: + shape = tensor_wrapper.tensor.ShapeAsNumpy() + else: + shape = [] + + return np.frombuffer(data, dtype=dtype).reshape(shape) + def get_tensor_type_str(self, tensor_type): """Get tensor type string representation when given TFLite tensor type""" @@ -728,7 +728,7 @@ def convert_relu(self, op): zero_point=zero_point_val, dtype=output_tensor_type_str) else: - out = _op.clip(in_expr, a_min=0, a_max=6) + out = _op.nn.relu(in_expr) if output_tensor.qnn_params: output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index bea0a1b6c006..52491b2de308 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -137,7 +137,6 @@ def _quantize_keras_model(keras_model, representative_data_gen): converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] converter.inference_input_type = tf.uint8 converter.inference_output_type = tf.uint8 - converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) return converter.convert() @@ -903,7 +902,7 @@ def test_forward_convolution(): _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized) _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized) _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized) - # dephtwise convolution with single input channel + # depthwise convolution with single input channel _test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized) # TFLite2 quantized convolution testing @@ -1814,7 +1813,7 @@ def _test_quantize_dequantize(data): # Keras model to force TFLite converter to insert 2 TFLite quantize ops. # First TFLite quantize op converts float32 tensor to int8 tensor - Qnn quantize. - # Second TLite quantize op converts int8 tensor to int8 tensor - Qnn requantize. + # Second TFLite quantize op converts int8 tensor to int8 tensor - Qnn requantize. data_in = tf.keras.layers.Input(shape=data.shape[1:]) relu = tf.keras.layers.ReLU()(data_in) add = tf.keras.layers.Add()([data_in, relu]) From 81fcbed572ee9bdd4bb2d95ed188a65c92d0bfd0 Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Thu, 2 Jul 2020 07:05:00 +0000 Subject: [PATCH 9/9] Retrigger CI.