2121import os
2222from copy import deepcopy
2323from pathlib import Path
24- from typing import Union , Callable , List
24+ from typing import Callable , List , Union
2525
2626import onnx
2727import onnxruntime as ort
2828
29+ from neural_compressor .common import Logger
2930from neural_compressor .onnxrt .quantization .calibrate import CalibrationDataReader
3031from neural_compressor .onnxrt .utils .onnx_model import ONNXModel
3132from neural_compressor .onnxrt .utils .utility import check_model_with_infer_shapes
32- from neural_compressor .common import Logger
3333
3434logger = Logger ().get_logger ()
3535
3636__all__ = [
3737 "layer_wise_quant" ,
3838]
3939
40+
4041def layer_wise_quant (
4142 model : Union [onnx .ModelProto , ONNXModel , Path , str ],
4243 quant_func : Callable ,
@@ -95,9 +96,7 @@ def layer_wise_quant(
9596 )
9697 )
9798 logger .debug (
98- "Will split model with these nodes for layer-wise quantization: {}" .format (
99- [node .name for node in split_nodes ]
100- )
99+ "Will split model with these nodes for layer-wise quantization: {}" .format ([node .name for node in split_nodes ])
101100 )
102101
103102 split_idx = 1
@@ -129,8 +128,12 @@ def layer_wise_quant(
129128 logger .info ("Quantize split model {}" .format (split_idx ))
130129 if require_data_reader :
131130 # process data_reader for current split and next split
132- current_data_reader = _filter_data_reader_for_current_split_model (split_model_part_1 .model , current_data_reader )
133- next_data_reader = _prepare_data_reader_for_next_split_model (split_model_part_1 .model_path , current_data_reader , providers )
131+ current_data_reader = _filter_data_reader_for_current_split_model (
132+ split_model_part_1 .model , current_data_reader
133+ )
134+ next_data_reader = _prepare_data_reader_for_next_split_model (
135+ split_model_part_1 .model_path , current_data_reader , providers
136+ )
134137 lwq_data_reader .append (next_data_reader )
135138
136139 # perform quantization
@@ -144,18 +147,17 @@ def layer_wise_quant(
144147 else :
145148 # perform quantization
146149 split_model_part_1_quantized = quant_func (
147- split_model_part_1 ,
148- weight_config = weight_config ,
149- return_modelproto = False ,
150- ** kwargs
150+ split_model_part_1 , weight_config = weight_config , return_modelproto = False , ** kwargs
151151 )
152152
153153 # check split model is valid
154154 try :
155155 ort .InferenceSession (split_model_part_1_quantized .model .SerializeToString (), providers = providers )
156156 except Exception as e :
157- logger .error ("Layer-wise quantized model {} can't be inferred correctly. "
158- "Please check the raise exception" .format (split_idx ))
157+ logger .error (
158+ "Layer-wise quantized model {} can't be inferred correctly. "
159+ "Please check the raise exception" .format (split_idx )
160+ )
159161 raise e
160162
161163 # merge split quantized model
@@ -174,7 +176,9 @@ def layer_wise_quant(
174176 if require_data_reader :
175177 # process data_reader for current split
176178 current_data_reader = lwq_data_reader .pop (0 )
177- current_data_reader = _filter_data_reader_for_current_split_model (split_model_part_2 .model , current_data_reader )
179+ current_data_reader = _filter_data_reader_for_current_split_model (
180+ split_model_part_2 .model , current_data_reader
181+ )
178182
179183 # perform quantization
180184 split_model_part_2_quantized = quant_func (
@@ -187,18 +191,17 @@ def layer_wise_quant(
187191 else :
188192 # perform quantization
189193 split_model_part_2_quantized = quant_func (
190- split_model_part_2 ,
191- weight_config = weight_config ,
192- return_modelproto = False ,
193- ** kwargs
194+ split_model_part_2 , weight_config = weight_config , return_modelproto = False , ** kwargs
194195 )
195196
196197 # check split model is valid
197198 try :
198199 ort .InferenceSession (split_model_part_2_quantized .model .SerializeToString (), providers = providers )
199200 except Exception as e :
200- logger .error ("Layer-wise quantized model {} can't be inferred correctly. "
201- "Please check the raise exception" .format (split_idx ))
201+ logger .error (
202+ "Layer-wise quantized model {} can't be inferred correctly. "
203+ "Please check the raise exception" .format (split_idx )
204+ )
202205 raise e
203206
204207 # merge split quantized model
@@ -210,6 +213,7 @@ def layer_wise_quant(
210213
211214 # reload external data to prevent external data file path errors
212215 from onnx .external_data_helper import load_external_data_for_model
216+
213217 load_external_data_for_model (quantized_model_merged .model , os .path .dirname (quantized_model_merged .model_path ))
214218
215219 return quantized_model_merged
@@ -246,13 +250,12 @@ def _filter_data_reader_for_current_split_model(model: onnx.ModelProto, data_rea
246250 if not inputs :
247251 break
248252 filter_input = {
249- input_name : input_tensor
250- for input_name , input_tensor in inputs .items ()
251- if input_name in input_names
253+ input_name : input_tensor for input_name , input_tensor in inputs .items () if input_name in input_names
252254 }
253255 filter_inputs .append (filter_input )
254256 return DataReader (filter_inputs )
255257
258+
256259def _prepare_data_reader_for_next_split_model (
257260 model_path : str ,
258261 data_reader : CalibrationDataReader ,
0 commit comments