diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 2fc82d74a08d..36221b7467aa 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -244,42 +244,92 @@ def get_tensors(self, tensors_idx_list): qnn_params = None tflite_qnn_params = tensor.Quantization() if tflite_qnn_params is not None: - scale = float(tflite_qnn_params.ScaleAsNumpy()) - zero_point = int(tflite_qnn_params.ZeroPointAsNumpy()) + # TFLite supports both per-tensor and per-axis (aka channel) quantization. For + # per-tensor quantization, scale and zero points are scalar values. For per-axis + # quantization, scale and zero points for the weights are tensors (activations are + # per-tensor quantized). However, the TFLite quantization spec puts restrictions on + # zero points for per-axis quantization. Specifically, the zero point is a tensor + # but all values are 0. More information can be found here - + # https://www.tensorflow.org/lite/performance/quantization_spec + + tflite_scale = tflite_qnn_params.ScaleAsNumpy() + tflite_zero_point = tflite_qnn_params.ZeroPointAsNumpy() + is_qnn_params_valid = True + + # Handle Per-axis and per-tensor cases + if isinstance(tflite_scale, np.ndarray): + assert isinstance(tflite_zero_point, np.ndarray) + + # Tensor - Per-axis quantization + if tflite_scale.size != 1 and tflite_zero_point.size != 1: + scale = tflite_scale + # Ensure that all zero points are zeros + zero_point = tflite_zero_point + if not np.all(zero_point == 0): + raise tvm.error.OpAttributeInvalid(\ + "TFLite per-axis quantization restricts all zero points to be" + + " 0, but a non-zero value is observed") + zero_point = int(zero_point[0]) + + # Scalar - Per-tensor quantization + elif tflite_scale.size == 1 and tflite_zero_point.size == 1: + scale = float(tflite_scale[0]) + zero_point = int(tflite_zero_point[0]) + + else: + raise NotImplementedError(\ + "Quantized type {} (scale) and {} (zero point) not supported" + .format(type(tflite_scale), type(tflite_zero_point))) + elif tflite_scale == 0 and tflite_zero_point == 0: + # Handle corner case for ops like quantized reshape whose second operand (shape) + # has zero scale and zero zero point. This is not used. + is_qnn_params_valid = False + else: + raise NotImplementedError("Quantized type {} not supported" + .format(type(tflite_scale))) + # Check that the scale and zero points are valid. - if scale != 0 or zero_point != 0: + if is_qnn_params_valid: qnn_params = dict() qnn_params['scale'] = relay.const(scale, 'float32') qnn_params['zero_point'] = relay.const(zero_point, 'int32') return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params)) return return_list - def get_tensor_value(self, tensor_wrapper): - """Get tensor buffer value from given tensor wrapper""" + + def get_tensor_type_as_numpy(self, tensor_wrapper): + """Returns np.dtype out of TensorType""" assert isinstance(tensor_wrapper, TensorWrapper) try: from tflite.TensorType import TensorType + return {TensorType.UINT8: np.uint8, + TensorType.INT8: np.int8, + TensorType.FLOAT32: np.float32, + TensorType.INT32: np.int32, + TensorType.INT64: np.int64, + TensorType.BOOL: np.bool_}[tensor_wrapper.tensor.Type()] except ImportError: raise ImportError("The tflite package must be installed") + except KeyError: + raise NotImplementedError("Tensor type '{}' currently not supported" + .format(tensor_wrapper.tensor.Type())) + + + def get_tensor_value(self, tensor_wrapper): + """Get tensor buffer value from given tensor wrapper""" + assert isinstance(tensor_wrapper, TensorWrapper) + + dtype = self.get_tensor_type_as_numpy(tensor_wrapper) + data = tensor_wrapper.buffer.DataAsNumpy() + + if tensor_wrapper.tensor.ShapeLength() != 0: + shape = tensor_wrapper.tensor.ShapeAsNumpy() + else: + shape = [] + + return np.frombuffer(data, dtype=dtype).reshape(shape) - if tensor_wrapper.tensor.Type() == TensorType.UINT8: - return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape( - tensor_wrapper.tensor.ShapeAsNumpy()) - if tensor_wrapper.tensor.Type() == TensorType.FLOAT32: - return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape( - tensor_wrapper.tensor.ShapeAsNumpy()) - if tensor_wrapper.tensor.Type() == TensorType.INT32: - return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape( - tensor_wrapper.tensor.ShapeAsNumpy()) - if tensor_wrapper.tensor.Type() == TensorType.INT64: - return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape( - tensor_wrapper.tensor.ShapeAsNumpy()) - if tensor_wrapper.tensor.Type() == TensorType.BOOL: - return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape( - tensor_wrapper.tensor.ShapeAsNumpy()) - raise NotImplementedError("Tensor type {} is currently not supported" - .format(str(tensor_wrapper.tensor.Type()))) def get_tensor_type_str(self, tensor_type): """Get tensor type string representation when given TFLite tensor type""" @@ -651,12 +701,43 @@ def convert_shape(self, op): def convert_relu(self, op): """Convert TFLite ReLU""" + try: + from tflite.ActivationFunctionType import ActivationFunctionType + except ImportError: + raise ImportError("The tflite package must be installed") + input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" - input_tensor = input_tensors[0] in_expr = self.get_expr(input_tensor.tensor_idx) - out = _op.nn.relu(in_expr) + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + if input_tensor.qnn_params: + # Quantize a float value to an quantized integer value + scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point']) + + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=in_expr, + fused_activation_fn=ActivationFunctionType.RELU, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: + out = _op.nn.relu(in_expr) + + if output_tensor.qnn_params: + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = _qnn.op.requantize(out, + input_scale=input_tensor.qnn_params['scale'], + input_zero_point=input_tensor.qnn_params['zero_point'], + output_scale=output_tensor.qnn_params['scale'], + output_zero_point=output_tensor.qnn_params['zero_point'], + out_dtype=output_tensor_type_str) return out @@ -692,6 +773,11 @@ def _hard_swish(data): def convert_relu6(self, op): """Convert TFLite ReLU6""" + try: + from tflite.ActivationFunctionType import ActivationFunctionType + except ImportError: + raise ImportError("The tflite package must be installed") + input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -705,17 +791,14 @@ def convert_relu6(self, op): # Quantize a float value to an quantized integer value scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale']) zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point']) - quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val) - # Get min/max of the input dtype. This will be used to ensure that - # clip a_min/a_max are not beyond the dtype range. - input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type()) - qmin = float(tvm.tir.op.min_value(input_tensor_type_str).value) - qmax = float(tvm.tir.op.max_value(input_tensor_type_str).value) - - out = _op.clip(in_expr, - a_min=max(qmin, quantize(0)), - a_max=min(qmax, quantize(6.0))) + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=in_expr, + fused_activation_fn=ActivationFunctionType.RELU6, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) else: out = _op.clip(in_expr, a_min=0, a_max=6) @@ -1604,9 +1687,9 @@ def convert_fully_connected(self, op): fully_connected_options.Init(op_options.Bytes, op_options.Pos) fused_activation_fn = fully_connected_options.FusedActivationFunction() - # weight tensor type should be UINT8 (quantization) or FLOAT32 + # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32 weight_tensor_type = weight_tensor.tensor.Type() - assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32) + assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32) weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) if self.has_expr(weight_tensor.tensor_idx): @@ -1795,9 +1878,9 @@ def convert_conv(self, op, conv_type): params['channels'] = int(output_channels) params['kernel_layout'] = 'HWIO' - # weight tensor type should be UINT8 (quantization) or FLOAT32 + # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32 weight_tensor_type = weight_tensor.tensor.Type() - assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32) + assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32) weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) in_expr = self.get_expr(input_tensor_idx) @@ -1856,9 +1939,15 @@ def convert_conv(self, op, conv_type): if output_tensor.qnn_params: # Calculate the intermediate scale and zero point of the int32 output. data_scale = input_tensor.qnn_params['scale'] - weight_scale = weight_tensor.qnn_params['scale'] data_scale_val = get_scalar_from_constant(data_scale) - weight_scale_val = get_scalar_from_constant(weight_scale) + + weight_scale = weight_tensor.qnn_params['scale'] + # If weight scale is scalar, it is per-tensor quantization + if isinstance(weight_scale, float): + weight_scale_val = get_scalar_from_constant(weight_scale) + else: + weight_scale_val = get_tensor_from_constant(weight_scale) + new_input_scale_val = data_scale_val * weight_scale_val new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') @@ -1869,7 +1958,8 @@ def convert_conv(self, op, conv_type): input_zero_point=new_input_zero_point, output_scale=output_tensor.qnn_params['scale'], output_zero_point=output_tensor.qnn_params['zero_point'], - out_dtype=output_tensor_type_str) + out_dtype=output_tensor_type_str, + axis=3) # Call activation function output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) @@ -1882,7 +1972,6 @@ def convert_conv(self, op, conv_type): dtype=output_tensor_type_str) else: out = self.convert_fused_activation_function(out, fused_activation_fn) - return out def convert_split(self, op): @@ -2594,17 +2683,27 @@ def convert_quantize(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] + input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type()) in_expr = self.get_expr(input_tensor.tensor_idx) output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) # The output must be quantized assert output_tensor.qnn_params - # Quantize the input - out = self.quantize(in_expr, output_tensor) + # TFLite Quantize op can also act as Requantize op + if input_tensor_type_str == "float32": + out = self.quantize(in_expr, output_tensor) + else: + out = _qnn.op.requantize(in_expr, + input_scale=input_tensor.qnn_params['scale'], + input_zero_point=input_tensor.qnn_params['zero_point'], + output_scale=output_tensor.qnn_params['scale'], + output_zero_point=output_tensor.qnn_params['zero_point'], + out_dtype=output_tensor_type_str) return out def convert_dequantize(self, op): @@ -2734,7 +2833,6 @@ def get_tensor_expr(self, tensor): else: type_str = self.get_tensor_type_str(tensor.tensor.Type()) expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str) - return expr @@ -2747,6 +2845,13 @@ def get_scalar_from_constant(expr): "value must be float32/int32" return np.asscalar(value) +def get_tensor_from_constant(expr): + """ Returns tensor of values from Relay constant node. """ + assert isinstance(expr, _expr.Constant) + value = expr.data.asnumpy() + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ + "value must be float32/int32" + return value def build_str_map(obj): """Build string map of TFLite enum int value diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 9412ab4393c5..5d2e360e0951 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -61,9 +61,19 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale // Kernel scale can be a vector of length output_channels or a scalar. - size_t axis = param->kernel_layout.find('O'); - CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined"; - AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale + if (param->groups == 1) { + size_t axis = param->kernel_layout.find('O'); + CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined"; + AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale + } else { + // Here, total number of output channels depend on depth multiplier. + size_t o_axis = param->kernel_layout.find('O'); + size_t i_axis = param->kernel_layout.find('I'); + CHECK(o_axis != std::string::npos || i_axis != std::string::npos) + << "Kernel layout attribute is not defined"; + AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis], + reporter); // kernel scale + } // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // Conv2D infer type function. diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 166eb2740edb..52491b2de308 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -73,6 +73,25 @@ def get_real_image(im_height, im_width): data = np.reshape(x, (1, im_height, im_width, 3)) return data + +def pre_processed_image(height, width): + repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' + img_name = 'elephant-299.jpg' + image_url = os.path.join(repo_base, img_name) + img_path = download_testdata(image_url, img_name, module='data') + image = tf.io.read_file(img_path) + image = tf.image.decode_jpeg(image, channels=3) + with tf.name_scope('eval_image'): + if image.dtype != tf.float32: + image = tf.image.convert_image_dtype(image, dtype=tf.float32) + image = tf.image.central_crop(image, central_fraction=0.875) + # Resize the image to the specified height and width. + image = tf.image.resize(image, [height, width], + align_corners=False) + image = tf.expand_dims(image, axis=0) + return image + + def get_real_image_object_detection(im_height, im_width): repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/' img_name = 'street_small.jpg' @@ -109,6 +128,18 @@ def vmobj_to_list(o): else: raise RuntimeError("Unknown object type: %s" % type(o)) + +def _quantize_keras_model(keras_model, representative_data_gen): + """Utility function to quantize a Keras model using TFLite converter.""" + converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) + converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] + converter.representative_dataset = representative_data_gen + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 + return converter.convert() + + def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', out_names=None, mode='graph_runtime'): """ Generic function to compile on relay and execute on tvm """ @@ -717,6 +748,70 @@ def test_forward_l2_pool2d(): # Convolution # ----------- + +def _test_tflite2_quantized_convolution(input_shape, kernel_shape, + dilations, strides, padding, data_format): + """ One iteration of TFLite2 quantized convolution with given shapes and attributes """ + data_format = "channels_last" if "NHWC" else "channels_first" + data = np.random.uniform(0, 1, input_shape).astype('float32') + kernel = np.random.uniform(0, 1, kernel_shape).astype('float32') + + data_in = tf.keras.layers.Input(shape=data.shape[1:]) + conv = tf.keras.layers.Conv2D(filters=kernel_shape[3], + kernel_size=(kernel_shape[0], kernel_shape[1]), + strides=strides, + padding=padding, + data_format=data_format, + activation='relu', + use_bias=False)(data_in) + keras_model = tf.keras.models.Model(data_in, conv) + keras_model.layers[1].set_weights([kernel]) + + # To create quantized values with dynamic range of activations, needs representative dataset + def representative_data_gen(): + for i in range(1): + yield [data] + + tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) + + tflite_output = run_tflite_graph(tflite_model_quant, data) + tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0","")) + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-2, atol=1e-2) + + +def _test_tflite2_quantized_depthwise_convolution(input_shape, kernel_shape, + dilations, strides, padding, data_format, depth_multiplier): + """One iteration of TFLite2 quantized depthwise convolution with given shapes and attributes""" + data_format = "channels_last" if "NHWC" else "channels_first" + data = np.random.uniform(0, 1, input_shape).astype('float32') + kernel = np.random.uniform(0, 1, kernel_shape).astype('float32') + + data_in = tf.keras.layers.Input(shape=data.shape[1:]) + conv = tf.keras.layers.DepthwiseConv2D(kernel_size=(kernel_shape[0], kernel_shape[1]), + strides=strides, + padding=padding, + data_format=data_format, + activation='relu', + use_bias=False, + depth_multiplier=depth_multiplier)(data_in) + keras_model = tf.keras.models.Model(data_in, conv) + keras_model.layers[1].set_weights([kernel]) + + + # To create quantized values with dynamic range of activations, needs representative dataset + def representative_data_gen(): + for i in range(1): + yield [data] + + tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) + + tflite_output = run_tflite_graph(tflite_model_quant, data) + tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0","")) + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-2, atol=1e-2) + + def _test_convolution(tensor_in_sizes, filter_in_sizes, dilations, strides, padding, data_format, is_depthwise=False, quantized=False): @@ -757,24 +852,38 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, data_format=data_format) if quantized: - # For now only quantized conv2d is supported - assert not is_depthwise - - # Quantized the inputs and feed them to the convolution - inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data') - inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter') - out = nn_ops.conv2d(inq_data, - inq_filter, - strides=strides, - padding=padding, - data_format=data_format) - out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out") - - # Set the input quantization range - input_range = {'in_data': (-100, 100)} if quantized else None - - # Compare - compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range) + if is_depthwise: + # Quantized the inputs and feed them to the convolution + inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data') + inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter') + out = nn_ops.depthwise_conv2d_native(inq_data, + inq_filter, + strides=strides, + padding=padding, + data_format=data_format) + out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out") + + # Set the input quantization range + input_range = {'in_data': (-100, 100)} if quantized else None + + # Compare + compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range) + else: + # Quantized the inputs and feed them to the convolution + inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data') + inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter') + out = nn_ops.conv2d(inq_data, + inq_filter, + strides=strides, + padding=padding, + data_format=data_format) + out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out") + + # Set the input quantization range + input_range = {'in_data': (-100, 100)} if quantized else None + + # Compare + compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range) else: data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out]) @@ -787,14 +896,30 @@ def test_forward_convolution(): _test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC', quantized=quantized) _test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC', quantized=quantized) - # depthwise convolution - _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True) - _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True) - _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True) - _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True) - _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True) - # dephtwise convolution with single input channel - _test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', 'NHWC', True) + # depthwise convolution + _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized) + _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized) + _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized) + _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized) + _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized) + # depthwise convolution with single input channel + _test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized) + + # TFLite2 quantized convolution testing + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): + _test_tflite2_quantized_convolution([1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC') + _test_tflite2_quantized_convolution([1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC') + _test_tflite2_quantized_convolution([1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC') + _test_tflite2_quantized_convolution([1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC') + + # depthwise convolution + _test_tflite2_quantized_depthwise_convolution([1, 8, 8, 128], [1, 1, 128, 1], [1, 1], [1, 1], + 'SAME', 'NHWC', 1) + _test_tflite2_quantized_depthwise_convolution([1, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], + 'VALID', 'NHWC', 1) + _test_tflite2_quantized_depthwise_convolution([1, 24, 24, 3], [7, 7, 3, 8], [1, 1], [2, 2], + 'SAME', 'NHWC', 8) + ####################################################################### @@ -1686,38 +1811,33 @@ def test_forward_squeeze(): def _test_quantize_dequantize(data): """ One iteration of quantize and dequantize """ - # Define a dummy model + # Keras model to force TFLite converter to insert 2 TFLite quantize ops. + # First TFLite quantize op converts float32 tensor to int8 tensor - Qnn quantize. + # Second TFLite quantize op converts int8 tensor to int8 tensor - Qnn requantize. data_in = tf.keras.layers.Input(shape=data.shape[1:]) - act_func = tf.keras.layers.Activation('linear') - keras_model = tf.keras.models.Model(data_in, act_func(data_in)) - - # Load the model - converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) + relu = tf.keras.layers.ReLU()(data_in) + add = tf.keras.layers.Add()([data_in, relu]) + concat = tf.keras.layers.Concatenate(axis=0)([relu, add]) + keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat) + input_name = data_in.name.split(":")[0] # To create quantized values with dynamic range of activations, needs representative dataset def representative_data_gen(): - for i in range(100): + for i in range(1): yield [data] - converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] - converter.representative_dataset = representative_data_gen - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 - - # Convert the model to TensorFlow Lite format - tflite_model_quant = converter.convert() + tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) tflite_output = run_tflite_graph(tflite_model_quant, data) - tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1') + tvm_output = run_tvm_graph(tflite_model_quant, data, input_name) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), - rtol=1e-5, atol=1e-5) + rtol=1e-5, atol=1e-2) def test_forward_quantize_dequantize(): """ Quantize Dequantize """ data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32") - if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'): + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): _test_quantize_dequantize(data) @@ -1945,16 +2065,38 @@ def test_forward_tanh(): # ReLu # ---- -def _test_relu(data): +def _test_relu(data, quantized=False): """ One iteration of ReLU """ - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - out = nn_ops.relu(in_data) - compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + + if quantized: + if package_version.parse(tf.VERSION) < package_version.parse('2.1.0'): + pytest.skip("Testcase requires tflite version >= 2.1.0") + data_in = tf.keras.layers.Input(shape=data.shape[1:]) + relu = tf.keras.layers.ReLU()(data_in) + keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu) + input_name = data_in.name.split(":")[0] + + # To create quantized values with dynamic range of activations, needs representative dataset + def representative_data_gen(): + for i in range(1): + yield [data] + + tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen) + + tflite_output = run_tflite_graph(tflite_model_quant, data) + tvm_output = run_tvm_graph(tflite_model_quant, data, input_name) + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-5, atol=1e-5) + else: + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = nn_ops.relu(in_data) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) def test_forward_relu(): """ ReLU """ _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)), quantized=True) ####################################################################### # ReLU6 @@ -2471,6 +2613,66 @@ def test_forward_qnn_mobilenet_v3_net(): tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) +def test_forward_tflite2_qnn_resnet50(): + """Test the Quantized TFLite version 2.1.0 Resnet50 model.""" + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): + tflite_model_file = download_testdata( + "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/resnet_50_quantized.tflite", + "resnet_50_quantized.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + data = pre_processed_image(224, 224) + + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1') + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + + +def test_forward_tflite2_qnn_inception_v1(): + """Test the Quantized TFLite version 2.1.0 Inception V1 model.""" + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): + tflite_model_file = download_testdata( + "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/inception_v1_quantized.tflite", + "inception_v1_quantized.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + data = pre_processed_image(224, 224) + + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1') + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + + +def test_forward_tflite2_qnn_mobilenet_v2(): + """Test the Quantized TFLite version 2.1.0 Mobilenet V2 model.""" + if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'): + tflite_model_file = download_testdata( + "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/mobilenet_v2_quantized.tflite", + "mobilenet_v2_quantized.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + data = pre_processed_image(224, 224) + + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1') + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + + ####################################################################### # Quantized SSD Mobilenet # ----------------------- @@ -2689,3 +2891,8 @@ def test_forward_mediapipe_hand_landmark(): #with Tflite 1.15.2 test_forward_qnn_mobilenet_v3_net() test_forward_qnn_coco_ssd_mobilenet_v1() + + # TFLite 2.1.0 quantized tests + test_forward_tflite2_qnn_resnet50() + test_forward_tflite2_qnn_inception_v1() + test_forward_tflite2_qnn_mobilenet_v2()