diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index 38b64a171cdc3..5423dc65f8639 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -46,6 +46,8 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType,
         is_weight_int8 = weight_qType == QuantType.QInt8
         self.is_weight_symmetric = is_weight_int8 if 'WeightSymmetric' not in self.extra_options else self.extra_options['WeightSymmetric']
         self.is_activation_symmetric = False if 'ActivationSymmetric' not in self.extra_options else self.extra_options['ActivationSymmetric']
+        self.op_types_support_per_channel_quantization = [] if 'OpTypesSupportPerChannelQuantization' not in extra_options \
+            else extra_options['OpTypesSupportPerChannelQuantization']
 
         self.input_qType = onnx_proto.TensorProto.INT8 if input_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8
         self.weight_qType = onnx_proto.TensorProto.INT8 if weight_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8
diff --git a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
index f8f5546b1512b..ebe3b7c71a789 100644
--- a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
+++ b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
@@ -19,4 +19,10 @@ def quantize(self):
         nodes_to_iterate = itertools.chain(node.input, node.output)
 
         for tensor_name in nodes_to_iterate:
-            self.quantizer.quantize_tensor(tensor_name)
+            if self.quantizer.is_per_channel():
+                if node.op_type in self.quantizer.op_types_support_per_channel_quantization:
+                    self.quantizer.quantize_tensor_per_channel(tensor_name, self.quantizer.qdq_channel_axis)
+                else:
+                    self.quantizer.quantize_tensor(tensor_name)
+            else:
+                self.quantizer.quantize_tensor(tensor_name)
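Usage sketch for the per-channel routing above (assumptions: the public quantize_static entry point from quantize.py, placeholder model paths, and a dummy calibration reader; only the 'OpTypesSupportPerChannelQuantization' and 'QDQChannelAxis' extra_options keys, the former read in onnx_quantizer.py above and the latter defined in qdq_quantizer.py below, come from this change):

    import numpy as np
    from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

    class DummyDataReader(CalibrationDataReader):
        # Placeholder calibration reader; a real one yields one feed dict per calibration sample.
        def __init__(self, feeds):
            self.iterator = iter(feeds)
        def get_next(self):
            return next(self.iterator, None)

    data_reader = DummyDataReader([{'input': np.random.rand(1, 3, 224, 224).astype(np.float32)}])
    quantize_static('model_fp32.onnx',             # placeholder input model path
                    'model_qdq_per_channel.onnx',  # placeholder output model path
                    data_reader,
                    quant_format=QuantFormat.QDQ,
                    per_channel=True,
                    weight_type=QuantType.QInt8,
                    extra_options={'OpTypesSupportPerChannelQuantization': ['Conv', 'MatMul'],
                                   'QDQChannelAxis': 0})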
diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index 839ee60b09a78..cd91e9d2a4f74 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -51,6 +51,15 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType,
         self.add_qdq_pair_to_weight = False if 'AddQDQPairToWeight' not in extra_options \
             else extra_options['AddQDQPairToWeight']
 
+        # The default behavior is that multiple nodes can share a QDQ pair as their inputs.
+        # In TRT, a QDQ pair can't be shared between nodes, so a dedicated QDQ pair is created for each node.
+        self.dedicated_qdq_pair = False if 'DedicatedQDQPair' not in extra_options else extra_options['DedicatedQDQPair']
+        if self.dedicated_qdq_pair:
+            self.tensor_to_its_receiving_nodes = {}
+
+        # Channel axis when per_channel is True
+        self.qdq_channel_axis = 0 if 'QDQChannelAxis' not in extra_options else extra_options['QDQChannelAxis']
+
     def quantize_tensor(self, tensor_name):
         weight = find_by_name(tensor_name, self.model.initializer())
         if weight is not None:
@@ -91,6 +100,14 @@ def remove_nodes(self):
         self.model.remove_nodes(self.nodes_to_remove)
 
     def quantize_model(self):
+        if self.dedicated_qdq_pair:
+            for node in self.model.nodes():
+                if self.should_quantize(node):
+                    for tensor_name in node.input:
+                        if tensor_name not in self.tensor_to_its_receiving_nodes:
+                            self.tensor_to_its_receiving_nodes[tensor_name] = []
+                        self.tensor_to_its_receiving_nodes[tensor_name].append(node)
+
         for node in self.model.nodes():
             if self.should_quantize(node):
                 op_quantizer = CreateQDQQuantizer(self, node)
@@ -156,30 +173,55 @@ def quantize_tensors(self):
                         "In static mode quantization params for inputs and outputs of nodes to be quantized are required."
                         .format(tensor_name))
 
-                q_input = tensor_name
-                q_output = tensor_name + "_QuantizeLinear"
-                dq_input = q_output
-                dq_output = tensor_name + "_DequantizeLinear"
-                if self.model.is_graph_output(tensor_name):
-                    q_input = tensor_name + "_QuantizeLinearInput"
-                    dq_output = tensor_name
-                    self.model.replace_output_of_all_nodes(tensor_name, q_input)
+                if self.dedicated_qdq_pair and tensor_name in self.tensor_to_its_receiving_nodes and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1:
+                    num_dedicated_qdq_pair = len(self.tensor_to_its_receiving_nodes[tensor_name])
+                    for i in range(num_dedicated_qdq_pair):
+                        postfix = str(i+1)
+                        q_input = tensor_name
+                        q_output = tensor_name + "_QuantizeLinear_" + postfix
+                        dq_input = q_output
+                        dq_output = tensor_name + "_DequantizeLinear_" + postfix
+                        quant_node_name = tensor_name + "_QuantizeLinear_" + postfix
+                        dequant_node_name = tensor_name + "_DequantizeLinear_" + postfix
+                        qlinear_node = onnx.helper.make_node("QuantizeLinear", [q_input, scale_name, zp_name],
+                                                             [q_output], quant_node_name)
+                        dequant_node = onnx.helper.make_node("DequantizeLinear",
+                                                             [dq_input, scale_name, zp_name],
+                                                             [dq_output],
+                                                             dequant_node_name)
+                        self.model.add_nodes([qlinear_node, dequant_node])
+
+                        node = self.tensor_to_its_receiving_nodes[tensor_name][i]
+                        self.model.replace_node_input(node, tensor_name, dq_output)
+
+                    quantized_value = QuantizedValue(tensor_name, dq_output, scale_name, zp_name,
+                                                     QuantizedValueType.Input)
+                    self.quantized_value_map[tensor_name] = quantized_value
                 else:
-                    self.model.replace_input_of_all_nodes(tensor_name, dq_output)
+                    q_input = tensor_name
+                    q_output = tensor_name + "_QuantizeLinear"
+                    dq_input = q_output
+                    dq_output = tensor_name + "_DequantizeLinear"
+                    if self.model.is_graph_output(tensor_name):
+                        q_input = tensor_name + "_QuantizeLinearInput"
+                        dq_output = tensor_name
+                        self.model.replace_output_of_all_nodes(tensor_name, q_input)
+                    else:
+                        self.model.replace_input_of_all_nodes(tensor_name, dq_output)
 
-                quant_node_name = tensor_name + "_QuantizeLinear"
-                dequant_node_name = tensor_name + "_DequantizeLinear"
-                qlinear_node = onnx.helper.make_node("QuantizeLinear", [q_input, scale_name, zp_name],
-                                                     [q_output], quant_node_name)
-                dequant_node = onnx.helper.make_node("DequantizeLinear",
-                                                     [dq_input, scale_name, zp_name],
-                                                     [dq_output],
-                                                     dequant_node_name)
-                self.model.add_nodes([qlinear_node, dequant_node])
+                    quant_node_name = tensor_name + "_QuantizeLinear"
+                    dequant_node_name = tensor_name + "_DequantizeLinear"
+                    qlinear_node = onnx.helper.make_node("QuantizeLinear", [q_input, scale_name, zp_name],
+                                                         [q_output], quant_node_name)
+                    dequant_node = onnx.helper.make_node("DequantizeLinear",
+                                                         [dq_input, scale_name, zp_name],
+                                                         [dq_output],
+                                                         dequant_node_name)
+                    self.model.add_nodes([qlinear_node, dequant_node])
 
-                quantized_value = QuantizedValue(tensor_name, dq_output, scale_name, zp_name,
-                                                 QuantizedValueType.Input)
-                self.quantized_value_map[tensor_name] = quantized_value
+                    quantized_value = QuantizedValue(tensor_name, dq_output, scale_name, zp_name,
+                                                     QuantizedValueType.Input)
+                    self.quantized_value_map[tensor_name] = quantized_value
 
     def quantize_bias_tensors(self):
         for bias_name, input_name, weight_name in self.bias_to_quantize:
diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index e70a84c23ba21..0a50832b59b6a 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -194,6 +194,8 @@ def quantize_static(model_input,
                                           inserts both QuantizeLinear/DeQuantizeLinear nodes to weight.
                 OpTypesToExcludeOutputQuantizatioin = list of op type : Default is []. If any op type is specified, it won't quantize
                                                       the output of ops with this specific op types.
+                DedicatedQDQPair = True/False : Default is False. When inserting QDQ pairs, multiple nodes can share a single QDQ pair as their inputs.
+                                                If True, an identical but dedicated QDQ pair is created for each node.
     '''
 
     mode = QuantizationMode.QLinearOps
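Usage sketch for the DedicatedQDQPair option documented above (same assumptions as the earlier sketch: the quantize_static entry point, placeholder paths, and a placeholder calibration reader; only the 'DedicatedQDQPair' key comes from this change):

    from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

    # With DedicatedQDQPair=True, a tensor consumed by several nodes gets one
    # QuantizeLinear/DequantizeLinear pair per consumer (e.g. for TensorRT),
    # instead of all consumers reading from a single DequantizeLinear output.
    # 'data_reader' is the same kind of CalibrationDataReader placeholder as above.
    quantize_static('model_fp32.onnx',   # placeholder input model path
                    'model_qdq.onnx',    # placeholder output model path
                    data_reader,
                    quant_format=QuantFormat.QDQ,
                    activation_type=QuantType.QInt8,
                    weight_type=QuantType.QInt8,
                    extra_options={'DedicatedQDQPair': True})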
diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py
index 445ff858c0018..d8d4280e37ac6 100644
--- a/onnxruntime/test/python/quantization/test_qdq.py
+++ b/onnxruntime/test/python/quantization/test_qdq.py
@@ -10,7 +10,7 @@
 import onnx
 import numpy as np
 from onnx import helper, TensorProto
-from onnxruntime.quantization import quantize_static, QuantType, QuantFormat
+from onnxruntime.quantization import quantize_static, QuantType, QuantFormat, QuantizationMode, QDQQuantizer
 from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_type_order
 
 class TestQDQFormat(unittest.TestCase):
@@ -24,6 +24,177 @@ def input_feeds(self, n, name2shape):
         dr = TestDataFeeds(input_data_list)
         return dr
 
+class TestQDQExtraOptions(unittest.TestCase):
+    def test_qdq_extra_options(self):
+        #    (input)
+        #       |
+        #      Add
+        #       |
+        #  ReduceMean
+        #       |
+        #      Add
+        #       |
+        #    (output)
+
+        initializers = []
+
+        input_tensor = helper.make_tensor_value_info('L', TensorProto.FLOAT, [5, 5])
+        output_tensor = helper.make_tensor_value_info('O', TensorProto.FLOAT, [5, 5])
+
+        add_weight_data_1 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(add_weight_data_1, name="M"))
+        add_weight_data_2 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(add_weight_data_2, name="N"))
+
+        add_node_1 = onnx.helper.make_node('Add', ['L', 'M'], ['P'], name='Add1')
+        reduce_mean_node = onnx.helper.make_node('ReduceMean', ['P'], ['Q'], keepdims=1, name='ReduceMean')
+        add_node_2 = onnx.helper.make_node('Add', ['Q', 'N'], ['O'], name='Add2')
+
+        graph = helper.make_graph([add_node_1, reduce_mean_node, add_node_2], 'QDQ_Test_Finetune', [input_tensor], [output_tensor], initializer=initializers)
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+        test_model_path = './test_qdq_finetune.onnx'
+        onnx.save(model, test_model_path)
+
+        compute_range = {
+            'P': [0.1, 0.1],
+            'Q': [0.1, 0.1],
+            'M': [0.1, 0.1],
+            'N': [0.1, 0.1],
+            'L': [0.1, 0.1],
+            'O': [0.1, 0.1],
+        }
+
+        op_types_to_quantize = ['Add']
+
+        mode = QuantizationMode.QLinearOps
+        model = onnx.load_model(test_model_path, False)
+        quantizer = QDQQuantizer(
+            model,
+            True, #per_channel
+            False, #reduce_range
+            mode,
+            True, #static
+            QuantType.QInt8, #weight_type
+            QuantType.QInt8, #activation_type
+            compute_range,
+            [], #nodes_to_quantize
+            ['Add2'], #nodes_to_exclude
+            op_types_to_quantize,
+            {'ActivationSymmetric': True, 'AddQDQPairToWeight': True, 'OpTypesToExcludeOutputQuantizatioin': []}) #extra_options
+        quantizer.quantize_model()
+        qdq_model_path = './test_qdq_finetune_qdq.onnx'
+        quantizer.model.save_model_to_file(qdq_model_path, False)
+
+        # A QDQ pair should be added to Add1 but not to Add2.
+        # A QDQ pair should be added to Add1's output as well.
+        qdq_added_to_node_output_flag = False
+        for node in quantizer.model.nodes():
+            if node.name == 'Add1':
+                for input in node.input:
+                    self.assertTrue("DequantizeLinear" in input)
+                for output in node.output:
+                    self.assertTrue("QuantizeLinear" not in output)
+
+            if node.name == 'Add2':
+                for input in node.input:
+                    self.assertTrue("DequantizeLinear" not in input)
+                for output in node.output:
+                    self.assertTrue("QuantizeLinear" not in output)
+
+            # This QuantizeLinear node should be followed by Add1
+            if node.name == 'P_QuantizeLinear':
+                qdq_added_to_node_output_flag = True
+                self.assertTrue(node.input[0] == 'P')
+
+        self.assertTrue(qdq_added_to_node_output_flag)
+
+
+    def test_qdq_extra_options_2(self):
+        #          (input)
+        #             |
+        #            Add
+        #          /  |  \
+        #    MatMul MatMul MatMul
+        #       |      |      |
+        #  (output)(output)(output)
+
+        initializers = []
+
+        input_tensor = helper.make_tensor_value_info('L', TensorProto.FLOAT, [5, 5])
+        output_tensor1 = helper.make_tensor_value_info('M', TensorProto.FLOAT, [5, 5])
+        output_tensor2 = helper.make_tensor_value_info('N', TensorProto.FLOAT, [5, 5])
+        output_tensor3 = helper.make_tensor_value_info('O', TensorProto.FLOAT, [5, 5])
+
+        add_weight_data = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(add_weight_data, name="P"))
+        matmul_weight_data_1 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(matmul_weight_data_1, name="Q"))
+        matmul_weight_data_2 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(matmul_weight_data_2, name="R"))
+        matmul_weight_data_3 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(matmul_weight_data_3, name="S"))
+
+        add_node = onnx.helper.make_node('Add', ['L', 'P'], ['T'], name='Add')
+        matmul_node_1 = onnx.helper.make_node('MatMul', ['T', 'Q'], ['M'], name='MatMul1')
+        matmul_node_2 = onnx.helper.make_node('MatMul', ['T', 'R'], ['N'], name='MatMul2')
+        matmul_node_3 = onnx.helper.make_node('MatMul', ['T', 'S'], ['O'], name='MatMul3')
+
+        graph = helper.make_graph([add_node, matmul_node_1, matmul_node_2, matmul_node_3], 'QDQ_Test_Finetune_2', [input_tensor], [output_tensor1, output_tensor2, output_tensor3], initializer=initializers)
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+        test_model_path = './test_qdq_finetune_2.onnx'
+        onnx.save(model, test_model_path)
+
+        compute_range = {
+            'L': [0.1, 0.1],
+            'M': [0.1, 0.1],
+            'N': [0.1, 0.1],
+            'O': [0.1, 0.1],
+            'P': [0.1, 0.1],
+            'Q': [0.1, 0.1],
+            'R': [0.1, 0.1],
+            'S': [0.1, 0.1],
+            'T': [0.1, 0.1],
+        }
+
+        op_types_to_quantize = ['Add', 'MatMul']
+
+        mode = QuantizationMode.QLinearOps
+        model = onnx.load_model(test_model_path, False)
+        quantizer = QDQQuantizer(
+            model,
+            True, #per_channel
+            False, #reduce_range
+            mode,
+            True, #static
+            QuantType.QInt8, #weight_type
+            QuantType.QInt8, #activation_type
+            compute_range,
+            [], #nodes_to_quantize
+            ['Add'], #nodes_to_exclude
+            op_types_to_quantize,
+            {'ActivationSymmetric': True, 'AddQDQPairToWeight': True, 'OpTypesToExcludeOutputQuantizatioin': op_types_to_quantize, 'DedicatedQDQPair': True}) #extra_options
+        quantizer.quantize_model()
+        qdq_model_path = './test_qdq_finetune_qdq_2.onnx'
+        quantizer.model.save_model_to_file(qdq_model_path, False)
+
+        # Three dedicated QDQ pairs should be generated, one feeding each MatMul node.
+        # No QDQ pair should be added to the Add node,
+        # and no QDQ pair should be added to any node's output.
+        for node in quantizer.model.nodes():
+            if node.name == 'MatMul1':
+                self.assertTrue("T_DequantizeLinear_1" in node.input)
+            if node.name == 'MatMul2':
+                self.assertTrue("T_DequantizeLinear_2" in node.input)
+            if node.name == 'MatMul3':
+                self.assertTrue("T_DequantizeLinear_3" in node.input)
+            if node.name == 'Add':
+                for input in node.input:
+                    self.assertTrue("DequantizeLinear" not in input)
+
+            # No QDQ pair should be added to MatMul's output
+            if node.op_type == 'QuantizeLinear':
+                self.assertTrue(node.input[0] not in ['M_QuantizeLinearInput', 'N_QuantizeLinearInput', 'O_QuantizeLinearInput'])
+
 class TestQDQFormatConv(TestQDQFormat):
     def construct_model_conv(self, output_model_path, input_shape, weight_shape, output_shape, has_bias):
         # (input)