From c859447024a833a88cd3aba8b1e89bc2e62501f0 Mon Sep 17 00:00:00 2001
From: Chi Lo
Date: Tue, 9 Nov 2021 20:43:15 +0000
Subject: [PATCH 1/9] Add fine-tuned QDQ options

---
 .../operators/qdq_base_operator.py            |   5 +-
 .../tools/quantization/qdq_quantizer.py       | 106 ++++++++++++++----
 2 files changed, 89 insertions(+), 22 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
index f8f5546b1512b..5514bbac8df69 100644
--- a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
+++ b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
@@ -19,4 +19,7 @@ def quantize(self):
         nodes_to_iterate = itertools.chain(node.input, node.output)
 
         for tensor_name in nodes_to_iterate:
-            self.quantizer.quantize_tensor(tensor_name)
+            if self.quantizer.is_per_channel():
+                self.quantizer.quantize_tensor_per_channel(tensor_name, 1)
+            else:
+                self.quantizer.quantize_tensor(tensor_name)
diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index 839ee60b09a78..c9af948f9fe7a 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -51,6 +51,21 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType,
         self.add_qdq_pair_to_weight = False if 'AddQDQPairToWeight' not in extra_options \
             else extra_options['AddQDQPairToWeight']
 
+        self.dedicated_qdq_pair = False if 'DedicatedQDQPair' not in extra_options else extra_options['DedicatedQDQPair']
+        if self.dedicated_qdq_pair:
+            self.tensor_to_its_receiving_nodes = {}
+
+        if "Add" in self.op_types_to_quantize:
+            self.add_qdq_to_add_node_followed_by_reduce_mean_node = False if 'AddQDQToAddNodeFollowedByReduceMeanNode' not in extra_options \
+                else extra_options['AddQDQToAddNodeFollowedByReduceMeanNode']
+        else:
+            self.add_qdq_to_add_node_followed_by_reduce_mean_node = False
+
+        if self.add_qdq_to_add_node_followed_by_reduce_mean_node:
+            self.reduce_mean_nodes = []
+            self.add_nodes = []
+            self.add_nodes_to_quantize = []
+
     def quantize_tensor(self, tensor_name):
         weight = find_by_name(tensor_name, self.model.initializer())
         if weight is not None:
@@ -90,7 +105,31 @@ def remove_node(self, node):
     def remove_nodes(self):
         self.model.remove_nodes(self.nodes_to_remove)
 
+    def pre_quantize_setup(self):
+        if self.add_qdq_to_add_node_followed_by_reduce_mean_node:
+            for node in self.model.nodes():
+                if node.op_type == "Add":
+                    self.add_nodes.append(node)
+                if node.op_type == "ReduceMean":
+                    self.reduce_mean_nodes.append(node)
+
+            for add_node in self.add_nodes:
+                for reduce_mean_node in self.reduce_mean_nodes:
+                    if add_node.output == reduce_mean_node.input:
+                        self.add_nodes_to_quantize.append(add_node.name)
+                if add_node.name not in self.add_nodes_to_quantize:
+                    self.nodes_to_exclude.append(add_node.name)
+
+        if self.dedicated_qdq_pair:
+            for node in self.model.nodes():
+                if self.should_quantize(node):
+                    for tensor_name in node.input:
+                        if tensor_name not in self.tensor_to_its_receiving_nodes:
+                            self.tensor_to_its_receiving_nodes[tensor_name] = []
+                        self.tensor_to_its_receiving_nodes[tensor_name].append(node)
+
     def quantize_model(self):
+        self.pre_quantize_setup()
         for node in self.model.nodes():
             if self.should_quantize(node):
                 op_quantizer = CreateQDQQuantizer(self, node)
@@ -156,30 +195,55 @@ def quantize_tensors(self):
                         "In static mode quantization params for inputs and outputs of nodes to be quantized are required."
                         .format(tensor_name))
 
-                q_input = tensor_name
-                q_output = tensor_name + "_QuantizeLinear"
-                dq_input = q_output
-                dq_output = tensor_name + "_DequantizeLinear"
-                if self.model.is_graph_output(tensor_name):
-                    q_input = tensor_name + "_QuantizeLinearInput"
-                    dq_output = tensor_name
-                    self.model.replace_output_of_all_nodes(tensor_name, q_input)
+                if self.dedicated_qdq_pair and tensor_name in self.tensor_to_its_receiving_nodes and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1:
+                    num_dedicated_qdq_pair = len(self.tensor_to_its_receiving_nodes[tensor_name])
+                    for i in range(num_dedicated_qdq_pair):
+                        postfix = str(i + 1)
+                        q_input = tensor_name
+                        q_output = tensor_name + "_QuantizeLinear_" + postfix
+                        dq_input = q_output
+                        dq_output = tensor_name + "_DequantizeLinear_" + postfix
+                        quant_node_name = tensor_name + "_QuantizeLinear_" + postfix
+                        dequant_node_name = tensor_name + "_DequantizeLinear_" + postfix
+                        qlinear_node = onnx.helper.make_node("QuantizeLinear", [q_input, scale_name, zp_name],
+                                                             [q_output], quant_node_name)
+                        dequant_node = onnx.helper.make_node("DequantizeLinear",
+                                                             [dq_input, scale_name, zp_name],
+                                                             [dq_output],
+                                                             dequant_node_name)
+                        self.model.add_nodes([qlinear_node, dequant_node])
+
+                        node = self.tensor_to_its_receiving_nodes[tensor_name][i]
+                        self.model.replace_node_input(node, tensor_name, dq_output)
+
+                    quantized_value = QuantizedValue(tensor_name, dq_output, scale_name, zp_name,
+                                                     QuantizedValueType.Input)
+                    self.quantized_value_map[tensor_name] = quantized_value
                 else:
-                    self.model.replace_input_of_all_nodes(tensor_name, dq_output)
+                    q_input = tensor_name
+                    q_output = tensor_name + "_QuantizeLinear"
+                    dq_input = q_output
+                    dq_output = tensor_name + "_DequantizeLinear"
+                    if self.model.is_graph_output(tensor_name):
+                        q_input = tensor_name + "_QuantizeLinearInput"
+                        dq_output = tensor_name
+                        self.model.replace_output_of_all_nodes(tensor_name, q_input)
+                    else:
+                        self.model.replace_input_of_all_nodes(tensor_name, dq_output)
 
-                quant_node_name = tensor_name + "_QuantizeLinear"
-                dequant_node_name = tensor_name + "_DequantizeLinear"
-                qlinear_node = onnx.helper.make_node("QuantizeLinear", [q_input, scale_name, zp_name],
-                                                     [q_output], quant_node_name)
-                dequant_node = onnx.helper.make_node("DequantizeLinear",
-                                                     [dq_input, scale_name, zp_name],
-                                                     [dq_output],
-                                                     dequant_node_name)
-                self.model.add_nodes([qlinear_node, dequant_node])
+                    quant_node_name = tensor_name + "_QuantizeLinear"
+                    dequant_node_name = tensor_name + "_DequantizeLinear"
+                    qlinear_node = onnx.helper.make_node("QuantizeLinear", [q_input, scale_name, zp_name],
+                                                         [q_output], quant_node_name)
+                    dequant_node = onnx.helper.make_node("DequantizeLinear",
+                                                         [dq_input, scale_name, zp_name],
+                                                         [dq_output],
+                                                         dequant_node_name)
+                    self.model.add_nodes([qlinear_node, dequant_node])
 
-                quantized_value = QuantizedValue(tensor_name, dq_output, scale_name, zp_name,
-                                                 QuantizedValueType.Input)
-                self.quantized_value_map[tensor_name] = quantized_value
+                    quantized_value = QuantizedValue(tensor_name, dq_output, scale_name, zp_name,
+                                                     QuantizedValueType.Input)
+                    self.quantized_value_map[tensor_name] = quantized_value
 
     def quantize_bias_tensors(self):
         for bias_name, input_name, weight_name in self.bias_to_quantize:
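Illustration (not part of the patches): with DedicatedQDQPair enabled, a tensor consumed by more than one node gets its own QuantizeLinear/DequantizeLinear pair per consumer instead of one shared pair. A minimal sketch of the naming scheme the loop above produces, using a hypothetical tensor 'T' with two consumers:

    from onnx import helper

    # Hypothetical names; mirrors the "_QuantizeLinear_<i>" postfix logic in quantize_tensors above.
    scale_name, zp_name = "T_scale", "T_zero_point"
    qdq_nodes = []
    for i in range(2):  # one dedicated pair per consuming node
        postfix = str(i + 1)
        q_out = "T_QuantizeLinear_" + postfix
        dq_out = "T_DequantizeLinear_" + postfix
        qdq_nodes.append(helper.make_node("QuantizeLinear", ["T", scale_name, zp_name], [q_out], q_out))
        qdq_nodes.append(helper.make_node("DequantizeLinear", [q_out, scale_name, zp_name], [dq_out], dq_out))
    # Consumer i is then rewired (replace_node_input) to read "T_DequantizeLinear_<i+1>" instead of "T".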
From 3b92da753cc20e99ef49870a13f8da33a3547e64 Mon Sep 17 00:00:00 2001
From: Chi Lo
Date: Wed, 10 Nov 2021 17:48:05 +0000
Subject: [PATCH 2/9] Add description

---
 onnxruntime/python/tools/quantization/qdq_quantizer.py | 8 ++++++--
 onnxruntime/python/tools/quantization/quantize.py      | 4 ++++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index c9af948f9fe7a..dee9e0c1aeebc 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -51,10 +51,14 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType,
         self.add_qdq_pair_to_weight = False if 'AddQDQPairToWeight' not in extra_options \
             else extra_options['AddQDQPairToWeight']
 
+        # The default behavior is that multiple nodes can share a QDQ pair as their inputs.
+        # In TRT, a QDQ pair can't be shared between nodes, so this option creates a dedicated QDQ pair for each node.
         self.dedicated_qdq_pair = False if 'DedicatedQDQPair' not in extra_options else extra_options['DedicatedQDQPair']
         if self.dedicated_qdq_pair:
             self.tensor_to_its_receiving_nodes = {}
 
+        # In TRT, it is recommended to add a QDQ pair to the inputs of an Add node that is followed by a ReduceMean node.
+        # If True and Add is in op_types_to_quantize, other Add nodes that don't meet this requirement won't get a QDQ pair.
         if "Add" in self.op_types_to_quantize:
             self.add_qdq_to_add_node_followed_by_reduce_mean_node = False if 'AddQDQToAddNodeFollowedByReduceMeanNode' not in extra_options \
                 else extra_options['AddQDQToAddNodeFollowedByReduceMeanNode']
@@ -105,7 +109,7 @@ def remove_node(self, node):
     def remove_nodes(self):
         self.model.remove_nodes(self.nodes_to_remove)
 
-    def pre_quantize_setup(self):
+    def pre_quantization_setup(self):
         if self.add_qdq_to_add_node_followed_by_reduce_mean_node:
             for node in self.model.nodes():
                 if node.op_type == "Add":
@@ -129,7 +133,7 @@ def pre_quantization_setup(self):
                         self.tensor_to_its_receiving_nodes[tensor_name].append(node)
 
     def quantize_model(self):
-        self.pre_quantize_setup()
+        self.pre_quantization_setup()
         for node in self.model.nodes():
             if self.should_quantize(node):
                 op_quantizer = CreateQDQQuantizer(self, node)
diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index e70a84c23ba21..47b2d33f14f7b 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -194,6 +194,10 @@ def quantize_static(model_input,
                                       inserts both QuantizeLinear/DeQuantizeLinear nodes to weight.
             OpTypesToExcludeOutputQuantizatioin = list of op type : Default is []. If any op type is specified, it won't quantize
                                                   the output of ops with this specific op types.
+            DedicatedQDQPair = True/False : Default is False. When inserting a QDQ pair, multiple nodes can share a single QDQ pair as their inputs.
+                               If True, it will create an identical and dedicated QDQ pair for each node.
+            AddQDQToAddNodeFollowedByReduceMeanNode = True/False : Default is False. QDQ pairs are added to every Add node if the Add op type is in op_types_to_quantize.
+                                                      If True, QDQ pairs are only added to Add nodes that are followed by a ReduceMean node.
     '''
 
     mode = QuantizationMode.QLinearOps
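Usage sketch (not part of the patches): the options documented above travel through extra_options of quantize_static. The model paths are placeholders and data_reader is assumed to be a CalibrationDataReader implementation:

    from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

    quantize_static('model.onnx', 'model.qdq.onnx', data_reader,
                    quant_format=QuantFormat.QDQ,
                    op_types_to_quantize=['Add', 'MatMul'],
                    activation_type=QuantType.QInt8,
                    weight_type=QuantType.QInt8,
                    extra_options={'DedicatedQDQPair': True,
                                   'AddQDQToAddNodeFollowedByReduceMeanNode': True})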
From ac5405b3d87f62abccd8bfdf37980203352dd986 Mon Sep 17 00:00:00 2001
From: Chi Lo
Date: Wed, 10 Nov 2021 21:21:40 +0000
Subject: [PATCH 3/9] Add unit tests

---
 .../test/python/quantization/test_qdq.py     | 173 +++++++++++++++++-
 1 file changed, 172 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py
index 445ff858c0018..95f2646044f41 100644
--- a/onnxruntime/test/python/quantization/test_qdq.py
+++ b/onnxruntime/test/python/quantization/test_qdq.py
@@ -10,7 +10,7 @@
 import onnx
 import numpy as np
 from onnx import helper, TensorProto
-from onnxruntime.quantization import quantize_static, QuantType, QuantFormat
+from onnxruntime.quantization import quantize_static, QuantType, QuantFormat, QuantizationMode, QDQQuantizer
 from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_type_order
 
 class TestQDQFormat(unittest.TestCase):
@@ -24,6 +24,177 @@ def input_feeds(self, n, name2shape):
         dr = TestDataFeeds(input_data_list)
         return dr
 
+class TestQDQExtraOptions(unittest.TestCase):
+    def test_qdq_extra_options(self):
+        #     (input)
+        #        |
+        #       Add
+        #        |
+        #   ReduceMean
+        #        |
+        #       Add
+        #        |
+        #     (output)
+
+        initializers = []
+
+        input_tensor = helper.make_tensor_value_info('L', TensorProto.FLOAT, [5, 5])
+        output_tensor = helper.make_tensor_value_info('O', TensorProto.FLOAT, [5, 5])
+
+        add_weight_data_1 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(add_weight_data_1, name="M"))
+        add_weight_data_2 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(add_weight_data_2, name="N"))
+
+        add_node_1 = onnx.helper.make_node('Add', ['L', 'M'], ['P'], name='Add1')
+        reduce_mean_node = onnx.helper.make_node('ReduceMean', ['P'], ['Q'], keepdims=1, name='ReduceMean')
+        add_node_2 = onnx.helper.make_node('Add', ['Q', 'N'], ['O'], name='Add2')
+
+        graph = helper.make_graph([add_node_1, reduce_mean_node, add_node_2], 'QDQ_Test_Finetune', [input_tensor], [output_tensor], initializer=initializers)
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+        test_model_path = './test_qdq_finetune.onnx'
+        onnx.save(model, test_model_path)
+
+        compute_range = {
+            'P': [0.1, 0.1],
+            'Q': [0.1, 0.1],
+            'M': [0.1, 0.1],
+            'N': [0.1, 0.1],
+            'L': [0.1, 0.1],
+            'O': [0.1, 0.1],
+        }
+
+        op_types_to_quantize = ['Add']
+
+        mode = QuantizationMode.QLinearOps
+        model = onnx.load_model(test_model_path, False)
+        quantizer = QDQQuantizer(
+            model,
+            True,  # per_channel
+            False,  # reduce_range
+            mode,
+            True,  # static
+            QuantType.QInt8,  # weight_type
+            QuantType.QInt8,  # activation_type
+            compute_range,
+            [],  # nodes_to_quantize
+            [],  # nodes_to_exclude
+            op_types_to_quantize,
+            {'ActivationSymmetric': True, 'AddQDQPairToWeight': True, 'AddQDQToAddNodeFollowedByReduceMeanNode': True, 'OpTypesToExcludeOutputQuantizatioin': []})  # extra_options
+        quantizer.quantize_model()
+        qdq_model_path = './test_qdq_finetune_qdq.onnx'
+        quantizer.model.save_model_to_file(qdq_model_path, False)
+
+        # A QDQ pair should be added to Add1 but not Add2.
+        # A QDQ pair should also be added to Add1's output.
+        qdq_added_to_node_output_flag = False
+        for node in quantizer.model.nodes():
+            if node.name == 'Add1':
+                for input in node.input:
+                    self.assertTrue("DequantizeLinear" in input)
+                for output in node.output:
+                    self.assertTrue("QuantizeLinear" not in output)
+
+            if node.name == 'Add2':
+                for input in node.input:
+                    self.assertTrue("DequantizeLinear" not in input)
+                for output in node.output:
+                    self.assertTrue("QuantizeLinear" not in output)
+
+            # This QuantizeLinear node should be followed by Add1
+            if node.name == 'P_QuantizeLinear':
+                qdq_added_to_node_output_flag = True
+                self.assertEqual(node.input[0], 'P')
+
+        self.assertTrue(qdq_added_to_node_output_flag)
+
+    def test_qdq_extra_options_2(self):
+        #        (input)
+        #           |
+        #          Add
+        #       /   |   \
+        #  MatMul MatMul MatMul
+        #     |      |      |
+        # (output)(output)(output)
+
+        initializers = []
+
+        input_tensor = helper.make_tensor_value_info('L', TensorProto.FLOAT, [5, 5])
+        output_tensor1 = helper.make_tensor_value_info('M', TensorProto.FLOAT, [5, 5])
+        output_tensor2 = helper.make_tensor_value_info('N', TensorProto.FLOAT, [5, 5])
+        output_tensor3 = helper.make_tensor_value_info('O', TensorProto.FLOAT, [5, 5])
+
+        add_weight_data = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(add_weight_data, name="P"))
+        matmul_weight_data_1 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(matmul_weight_data_1, name="Q"))
+        matmul_weight_data_2 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(matmul_weight_data_2, name="R"))
+        matmul_weight_data_3 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32)
+        initializers.append(onnx.numpy_helper.from_array(matmul_weight_data_3, name="S"))
+
+        add_node = onnx.helper.make_node('Add', ['L', 'P'], ['T'], name='Add')
+        matmul_node_1 = onnx.helper.make_node('MatMul', ['T', 'Q'], ['M'], name='MatMul1')
+        matmul_node_2 = onnx.helper.make_node('MatMul', ['T', 'R'], ['N'], name='MatMul2')
+        matmul_node_3 = onnx.helper.make_node('MatMul', ['T', 'S'], ['O'], name='MatMul3')
+
+        graph = helper.make_graph([add_node, matmul_node_1, matmul_node_2, matmul_node_3], 'QDQ_Test_Finetune_2', [input_tensor], [output_tensor1, output_tensor2, output_tensor3], initializer=initializers)
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+        test_model_path = './test_qdq_finetune_2.onnx'
+        onnx.save(model, test_model_path)
+
+        compute_range = {
+            'L': [0.1, 0.1],
+            'M': [0.1, 0.1],
+            'N': [0.1, 0.1],
+            'O': [0.1, 0.1],
+            'P': [0.1, 0.1],
+            'Q': [0.1, 0.1],
+            'R': [0.1, 0.1],
+            'S': [0.1, 0.1],
+            'T': [0.1, 0.1],
+        }
+
+        op_types_to_quantize = ['Add', 'MatMul']
+
+        mode = QuantizationMode.QLinearOps
+        model = onnx.load_model(test_model_path, False)
+        quantizer = QDQQuantizer(
+            model,
+            True,  # per_channel
+            False,  # reduce_range
+            mode,
+            True,  # static
+            QuantType.QInt8,  # weight_type
+            QuantType.QInt8,  # activation_type
+            compute_range,
+            [],  # nodes_to_quantize
+            [],  # nodes_to_exclude
+            op_types_to_quantize,
+            {'ActivationSymmetric': True, 'AddQDQPairToWeight': True, 'AddQDQToAddNodeFollowedByReduceMeanNode': True, 'OpTypesToExcludeOutputQuantizatioin': op_types_to_quantize, 'DedicatedQDQPair': True})  # extra_options
+        quantizer.quantize_model()
+        qdq_model_path = './test_qdq_finetune_qdq_2.onnx'
+        quantizer.model.save_model_to_file(qdq_model_path, False)
+
+        # Three dedicated QDQ pairs should be generated, one feeding each MatMul node.
+        # Also, no QDQ pair should be added to the Add node,
+        # and no QDQ pair should be added to any node's output.
+        for node in quantizer.model.nodes():
+            if node.name == 'MatMul1':
+                self.assertTrue("T_DequantizeLinear_1" in node.input)
+            if node.name == 'MatMul2':
+                self.assertTrue("T_DequantizeLinear_2" in node.input)
+            if node.name == 'MatMul3':
+                self.assertTrue("T_DequantizeLinear_3" in node.input)
+            if node.name == 'Add':
+                for input in node.input:
+                    self.assertTrue("DequantizeLinear" not in input)
+
+            # No QDQ pair should be added to MatMul's output
+            if node.op_type == 'QuantizeLinear':
+                self.assertTrue(node.input[0] not in ['M_QuantizeLinearInput', 'N_QuantizeLinearInput', 'O_QuantizeLinearInput'])
+
 class TestQDQFormatConv(TestQDQFormat):
     def construct_model_conv(self, output_model_path, input_shape, weight_shape, output_shape, has_bias):
         # (input)
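A quick way to eyeball what these tests produce (a sketch reusing the qdq_model_path written by the second test):

    import onnx
    from collections import Counter

    model = onnx.load('./test_qdq_finetune_qdq_2.onnx')
    # With DedicatedQDQPair, 'T' fans out through three Q/DQ pairs, one per MatMul consumer.
    print(Counter(node.op_type for node in model.graph.node))
    for node in model.graph.node:
        if node.op_type == 'DequantizeLinear':
            print(node.name, '->', list(node.output))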
From b30daea64608a9d0d8fe41ca8c335d4e76e09027 Mon Sep 17 00:00:00 2001
From: Chi Lo
Date: Thu, 11 Nov 2021 00:45:21 +0000
Subject: [PATCH 4/9] Modify for channel axis

---
 .../python/tools/quantization/operators/qdq_base_operator.py | 2 +-
 onnxruntime/python/tools/quantization/qdq_quantizer.py       | 3 +++
 onnxruntime/python/tools/quantization/quantize.py            | 1 +
 3 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
index 5514bbac8df69..a6dbde2730786 100644
--- a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
+++ b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
@@ -20,6 +20,6 @@ def quantize(self):
 
         for tensor_name in nodes_to_iterate:
             if self.quantizer.is_per_channel():
-                self.quantizer.quantize_tensor_per_channel(tensor_name, 1)
+                self.quantizer.quantize_tensor_per_channel(tensor_name, self.quantizer.qdq_channel_axis)
             else:
                 self.quantizer.quantize_tensor(tensor_name)
diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index dee9e0c1aeebc..f874d2a9092e3 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -57,6 +57,9 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType,
         if self.dedicated_qdq_pair:
             self.tensor_to_its_receiving_nodes = {}
 
+        # Channel axis when per_channel is True
+        self.qdq_channel_axis = 0 if 'QDQChannelAxis' not in extra_options else extra_options['QDQChannelAxis']
+
         # In TRT, it is recommended to add a QDQ pair to the inputs of an Add node that is followed by a ReduceMean node.
         # If True and Add is in op_types_to_quantize, other Add nodes that don't meet this requirement won't get a QDQ pair.
         if "Add" in self.op_types_to_quantize:
diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index 47b2d33f14f7b..bd645af7ef0ca 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -198,6 +198,7 @@ def quantize_static(model_input,
                                If True, it will create an identical and dedicated QDQ pair for each node.
             AddQDQToAddNodeFollowedByReduceMeanNode = True/False : Default is False. QDQ pairs are added to every Add node if the Add op type is in op_types_to_quantize.
                                                       If True, QDQ pairs are only added to Add nodes that are followed by a ReduceMean node.
+            QDQChannelAxis = Integer : Default is 0. Channel axis for QDQ pair when per_channel is True.
     '''
 
     mode = QuantizationMode.QLinearOps
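Background sketch (not part of the patches): what the channel axis means for per-channel quantization. Simplified numpy, not the quantizer's exact implementation; symmetric int8 scales are computed per slice along the chosen axis:

    import numpy as np

    def per_channel_scales(weight, axis=0):
        # One scale per slice along `axis`, taken from that slice's max magnitude.
        moved = np.moveaxis(weight, axis, 0)
        max_abs = np.abs(moved.reshape(moved.shape[0], -1)).max(axis=1)
        return max_abs / 127.0

    w = np.random.randn(8, 16).astype(np.float32)
    print(per_channel_scales(w, axis=0).shape)  # (8,)  -- default QDQChannelAxis = 0
    print(per_channel_scales(w, axis=1).shape)  # (16,) -- QDQChannelAxis = 1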
From 0ba3c58a421f72100e3d7a633b81a90523cdc315 Mon Sep 17 00:00:00 2001
From: Chi Lo
Date: Thu, 18 Nov 2021 05:39:19 +0000
Subject: [PATCH 5/9] Remove too-specific feature. Move this implementation to
 an e2e example

---
 .../tools/quantization/qdq_quantizer.py       | 31 +------------------
 .../python/tools/quantization/quantize.py     |  2 --
 2 files changed, 1 insertion(+), 32 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index f874d2a9092e3..cd91e9d2a4f74 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -60,19 +60,6 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType,
         # Channel axis when per_channel is True
         self.qdq_channel_axis = 0 if 'QDQChannelAxis' not in extra_options else extra_options['QDQChannelAxis']
 
-        # In TRT, it is recommended to add a QDQ pair to the inputs of an Add node that is followed by a ReduceMean node.
-        # If True and Add is in op_types_to_quantize, other Add nodes that don't meet this requirement won't get a QDQ pair.
-        if "Add" in self.op_types_to_quantize:
-            self.add_qdq_to_add_node_followed_by_reduce_mean_node = False if 'AddQDQToAddNodeFollowedByReduceMeanNode' not in extra_options \
-                else extra_options['AddQDQToAddNodeFollowedByReduceMeanNode']
-        else:
-            self.add_qdq_to_add_node_followed_by_reduce_mean_node = False
-
-        if self.add_qdq_to_add_node_followed_by_reduce_mean_node:
-            self.reduce_mean_nodes = []
-            self.add_nodes = []
-            self.add_nodes_to_quantize = []
-
     def quantize_tensor(self, tensor_name):
         weight = find_by_name(tensor_name, self.model.initializer())
         if weight is not None:
@@ -112,21 +99,7 @@ def remove_node(self, node):
     def remove_nodes(self):
         self.model.remove_nodes(self.nodes_to_remove)
 
-    def pre_quantization_setup(self):
-        if self.add_qdq_to_add_node_followed_by_reduce_mean_node:
-            for node in self.model.nodes():
-                if node.op_type == "Add":
-                    self.add_nodes.append(node)
-                if node.op_type == "ReduceMean":
-                    self.reduce_mean_nodes.append(node)
-
-            for add_node in self.add_nodes:
-                for reduce_mean_node in self.reduce_mean_nodes:
-                    if add_node.output == reduce_mean_node.input:
-                        self.add_nodes_to_quantize.append(add_node.name)
-                if add_node.name not in self.add_nodes_to_quantize:
-                    self.nodes_to_exclude.append(add_node.name)
-
+    def quantize_model(self):
         if self.dedicated_qdq_pair:
             for node in self.model.nodes():
                 if self.should_quantize(node):
@@ -135,8 +108,6 @@ def pre_quantization_setup(self):
                         self.tensor_to_its_receiving_nodes[tensor_name].append(node)
 
-    def quantize_model(self):
-        self.pre_quantization_setup()
         for node in self.model.nodes():
             if self.should_quantize(node):
                 op_quantizer = CreateQDQQuantizer(self, node)
diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index bd645af7ef0ca..1955dc687867e 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -196,8 +196,6 @@ def quantize_static(model_input,
                                                   the output of ops with this specific op types.
             DedicatedQDQPair = True/False : Default is False. When inserting a QDQ pair, multiple nodes can share a single QDQ pair as their inputs.
                                If True, it will create an identical and dedicated QDQ pair for each node.
-            AddQDQToAddNodeFollowedByReduceMeanNode = True/False : Default is False. QDQ pairs are added to every Add node if the Add op type is in op_types_to_quantize.
-                                                      If True, QDQ pairs are only added to Add nodes that are followed by a ReduceMean node.
             QDQChannelAxis = Integer : Default is 0. Channel axis for QDQ pair when per_channel is True.
     '''
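With the feature removed, the same Add-followed-by-ReduceMean selection can be done outside the quantizer and handed in through nodes_to_exclude (a sketch with a placeholder model path; PATCH 7 below hard-codes the equivalent exclusions in the unit tests):

    import onnx

    model = onnx.load('model.onnx')  # placeholder path
    consumers = {}
    for node in model.graph.node:
        for name in node.input:
            consumers.setdefault(name, []).append(node)

    # Exclude Add nodes that are NOT followed by a ReduceMean, mimicking the removed option.
    nodes_to_exclude = [
        node.name for node in model.graph.node
        if node.op_type == 'Add' and not any(
            consumer.op_type == 'ReduceMean' for consumer in consumers.get(node.output[0], []))
    ]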
From 5f04fbe70496e6ab6f4bbcd29a6f2fef66b607fb Mon Sep 17 00:00:00 2001
From: Chi Lo
Date: Thu, 18 Nov 2021 06:38:24 +0000
Subject: [PATCH 6/9] Add OpTypesSupportPerChannelQuantization

---
 onnxruntime/python/tools/quantization/onnx_quantizer.py  | 2 ++
 .../tools/quantization/operators/qdq_base_operator.py    | 9 ++++++++-
 onnxruntime/python/tools/quantization/quantize.py        | 1 +
 3 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index 38b64a171cdc3..5423dc65f8639 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -46,6 +46,8 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType,
         is_weight_int8 = weight_qType == QuantType.QInt8
         self.is_weight_symmetric = is_weight_int8 if 'WeightSymmetric' not in self.extra_options else self.extra_options['WeightSymmetric']
         self.is_activation_symmetric = False if 'ActivationSymmetric' not in self.extra_options else self.extra_options['ActivationSymmetric']
+        self.op_types_support_per_channel_quantization = [] if 'OpTypesSupportPerChannelQuantization' not in extra_options \
+            else extra_options['OpTypesSupportPerChannelQuantization']
 
         self.input_qType = onnx_proto.TensorProto.INT8 if input_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8
         self.weight_qType = onnx_proto.TensorProto.INT8 if weight_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8
diff --git a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
index a6dbde2730786..6e9d0f294802d 100644
--- a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
+++ b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
@@ -1,6 +1,7 @@
 import itertools
 from .base_operator import QuantOperatorBase
 from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, quantize_nparray
+import logging
 
 
 class QDQOperatorBase:
@@ -20,6 +21,12 @@ def quantize(self):
 
         for tensor_name in nodes_to_iterate:
             if self.quantizer.is_per_channel():
-                self.quantizer.quantize_tensor_per_channel(tensor_name, self.quantizer.qdq_channel_axis)
+                if node.op_type in self.quantizer.op_types_support_per_channel_quantization:
+                    self.quantizer.quantize_tensor_per_channel(tensor_name, self.quantizer.qdq_channel_axis)
+                else:
+                    logging.warning(
+                        "{} doesn't support per-channel quantization. Quantizing tensor {} per-tensor instead.".format(
+                            node.op_type, tensor_name))
+                    self.quantizer.quantize_tensor(tensor_name)
             else:
                 self.quantizer.quantize_tensor(tensor_name)
diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index 1955dc687867e..2b00c062afc34 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -196,6 +196,7 @@ def quantize_static(model_input,
                                                   the output of ops with this specific op types.
             DedicatedQDQPair = True/False : Default is False. When inserting a QDQ pair, multiple nodes can share a single QDQ pair as their inputs.
                                If True, it will create an identical and dedicated QDQ pair for each node.
+            OpTypesSupportPerChannelQuantization = list of op type : Default is []. List of op types that have per-channel quantization support.
             QDQChannelAxis = Integer : Default is 0. Channel axis for QDQ pair when per_channel is True.
     '''
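Usage sketch (not part of the patches): combining the two per-channel knobs added so far. The op-type list is an arbitrary caller-chosen example, not a recommendation:

    extra_options = {
        'OpTypesSupportPerChannelQuantization': ['Conv', 'MatMul'],  # example list
        'QDQChannelAxis': 0,
    }
    # Pass to quantize_static(..., per_channel=True, extra_options=extra_options).
    # Quantized op types outside the list fall back to per-tensor
    # (with a warning until PATCH 9 removes it).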
From fa51f2676635aca02708078523fee0e9edf29551 Mon Sep 17 00:00:00 2001
From: Chi Lo
Date: Thu, 18 Nov 2021 07:29:59 +0000
Subject: [PATCH 7/9] Fix bug for unit test

---
 onnxruntime/test/python/quantization/test_qdq.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py
index 95f2646044f41..d8d4280e37ac6 100644
--- a/onnxruntime/test/python/quantization/test_qdq.py
+++ b/onnxruntime/test/python/quantization/test_qdq.py
@@ -78,9 +78,9 @@ def test_qdq_extra_options(self):
             QuantType.QInt8,  # activation_type
             compute_range,
             [],  # nodes_to_quantize
-            [],  # nodes_to_exclude
+            ['Add2'],  # nodes_to_exclude
             op_types_to_quantize,
-            {'ActivationSymmetric': True, 'AddQDQPairToWeight': True, 'AddQDQToAddNodeFollowedByReduceMeanNode': True, 'OpTypesToExcludeOutputQuantizatioin': []})  # extra_options
+            {'ActivationSymmetric': True, 'AddQDQPairToWeight': True, 'OpTypesToExcludeOutputQuantizatioin': []})  # extra_options
         quantizer.quantize_model()
         qdq_model_path = './test_qdq_finetune_qdq.onnx'
         quantizer.model.save_model_to_file(qdq_model_path, False)
@@ -170,9 +170,9 @@ def test_qdq_extra_options_2(self):
             QuantType.QInt8,  # activation_type
             compute_range,
             [],  # nodes_to_quantize
-            [],  # nodes_to_exclude
+            ['Add'],  # nodes_to_exclude
             op_types_to_quantize,
-            {'ActivationSymmetric': True, 'AddQDQPairToWeight': True, 'AddQDQToAddNodeFollowedByReduceMeanNode': True, 'OpTypesToExcludeOutputQuantizatioin': op_types_to_quantize, 'DedicatedQDQPair': True})  # extra_options
+            {'ActivationSymmetric': True, 'AddQDQPairToWeight': True, 'OpTypesToExcludeOutputQuantizatioin': op_types_to_quantize, 'DedicatedQDQPair': True})  # extra_options
         quantizer.quantize_model()
         qdq_model_path = './test_qdq_finetune_qdq_2.onnx'
         quantizer.model.save_model_to_file(qdq_model_path, False)
From ca35bf12e9b0085fe95ebf20a2223bccfcf29c29 Mon Sep 17 00:00:00 2001
From: stevenlix <38092805+stevenlix@users.noreply.github.com>
Date: Mon, 29 Nov 2021 16:06:16 -0800
Subject: [PATCH 8/9] Keep flags OpTypesSupportPerChannelQuantization and
 QDQChannelAxis for internal use

Will have a follow-up PR to fine-tune the code.
---
 onnxruntime/python/tools/quantization/quantize.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index 2b00c062afc34..0a50832b59b6a 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -195,9 +195,7 @@ def quantize_static(model_input,
             OpTypesToExcludeOutputQuantizatioin = list of op type : Default is []. If any op type is specified, it won't quantize
                                                   the output of ops with this specific op types.
             DedicatedQDQPair = True/False : Default is False. When inserting a QDQ pair, multiple nodes can share a single QDQ pair as their inputs.
-                               If True, it will create an identical and dedicated QDQ pair for each node.
-            OpTypesSupportPerChannelQuantization = list of op type : Default is []. List of op types that have per-channel quantization support.
-            QDQChannelAxis = Integer : Default is 0. Channel axis for QDQ pair when per_channel is True.
+                               If True, it will create an identical and dedicated QDQ pair for each node.
     '''
 
     mode = QuantizationMode.QLinearOps
From 87ca68c861823fa88d97b6201df3d3c2c7682bcd Mon Sep 17 00:00:00 2001
From: Yufeng Li
Date: Mon, 29 Nov 2021 18:53:24 -0800
Subject: [PATCH 9/9] Remove unnecessary warning

---
 .../python/tools/quantization/operators/qdq_base_operator.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
index 6e9d0f294802d..ebe3b7c71a789 100644
--- a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
+++ b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py
@@ -1,7 +1,6 @@
 import itertools
 from .base_operator import QuantOperatorBase
 from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, quantize_nparray
-import logging
 
 
 class QDQOperatorBase:
@@ -24,9 +23,6 @@ def quantize(self):
             if node.op_type in self.quantizer.op_types_support_per_channel_quantization:
                 self.quantizer.quantize_tensor_per_channel(tensor_name, self.quantizer.qdq_channel_axis)
             else:
-                logging.warning(
-                    "{} doesn't support per-channel quantization. Quantizing tensor {} per-tensor instead.".format(
-                        node.op_type, tensor_name))
                 self.quantizer.quantize_tensor(tensor_name)