17
17
18
18
import torch
19
19
import torch .ao .quantization .quantizer .x86_inductor_quantizer as xiq
20
- from torch .ao .quantization .observer import HistogramObserver , MinMaxObserver
20
+ from torch .ao .quantization .observer import HistogramObserver , MinMaxObserver , PlaceholderObserver
21
21
from torch .ao .quantization .quantizer import QuantizationSpec
22
22
from torch .ao .quantization .quantizer .x86_inductor_quantizer import QuantizationConfig , X86InductorQuantizer
23
23
from typing_extensions import TypeAlias
@@ -172,30 +172,41 @@ def postprocess_model(model, mode, quantizer):
172
172
del model .quantizer
173
173
174
174
175
- def create_quant_spec_from_config (dtype , sym , granularity , algo ) -> QuantizationSpec :
175
+ def create_quant_spec_from_config (dtype , sym , granularity , algo , is_dynamic = False ) -> QuantizationSpec :
176
176
dtype_mapping : Dict [str , torch .dtype ] = {"int8" : torch .int8 , "uint8" : torch .uint8 }
177
+ select_dtype = dtype_mapping [dtype ]
178
+ min_max_mapping = {torch .int8 : (- 128 , 127 ), torch .uint8 : (0 , 255 )}
177
179
qscheme_mapping = {
178
180
"per_channel" : {True : torch .per_channel_symmetric , False : torch .per_tensor_affine },
179
181
"per_tensor" : {True : torch .per_tensor_symmetric , False : torch .per_tensor_affine },
180
182
}
181
183
observer_mapping = {
184
+ "placeholder" : PlaceholderObserver ,
182
185
"minmax" : MinMaxObserver ,
183
186
"kl" : HistogramObserver ,
184
187
}
188
+ # Force to use placeholder observer for dynamic quantization
189
+ if is_dynamic :
190
+ algo = "placeholder"
185
191
# algo
186
192
observer_or_fake_quant_ctr = observer_mapping [algo ]
187
193
# qscheme
188
194
qscheme = qscheme_mapping [granularity ][sym ]
189
195
quantization_spec = QuantizationSpec (
190
- dtype = dtype_mapping [dtype ], observer_or_fake_quant_ctr = observer_or_fake_quant_ctr , qscheme = qscheme
196
+ dtype = select_dtype ,
197
+ quant_min = min_max_mapping [select_dtype ][0 ],
198
+ quant_max = min_max_mapping [select_dtype ][1 ],
199
+ observer_or_fake_quant_ctr = observer_or_fake_quant_ctr ,
200
+ qscheme = qscheme ,
201
+ is_dynamic = is_dynamic ,
191
202
)
192
203
return quantization_spec
193
204
194
205
195
- def _map_inc_config_to_torch_quant_config (inc_config ) -> QuantizationConfig :
196
- default_quant_config = xiq .get_default_x86_inductor_quantization_config ()
206
+ def _map_inc_config_to_torch_quant_config (inc_config , is_dynamic = False ) -> QuantizationConfig :
207
+ default_quant_config = xiq .get_default_x86_inductor_quantization_config (is_dynamic = is_dynamic )
197
208
input_act_quant_spec = create_quant_spec_from_config (
198
- inc_config .act_dtype , inc_config .act_sym , inc_config .act_granularity , inc_config .act_algo
209
+ inc_config .act_dtype , inc_config .act_sym , inc_config .act_granularity , inc_config .act_algo , is_dynamic = is_dynamic
199
210
)
200
211
weight_quant_spec = create_quant_spec_from_config (
201
212
inc_config .w_dtype , inc_config .w_sym , inc_config .w_granularity , inc_config .w_algo
@@ -210,14 +221,14 @@ def _map_inc_config_to_torch_quant_config(inc_config) -> QuantizationConfig:
210
221
return quant_config
211
222
212
223
213
- def create_xiq_quantizer_from_pt2e_config (config ) -> X86InductorQuantizer :
224
+ def create_xiq_quantizer_from_pt2e_config (config , is_dynamic = False ) -> X86InductorQuantizer :
214
225
quantizer = xiq .X86InductorQuantizer ()
215
226
# set global
216
- global_config = _map_inc_config_to_torch_quant_config (config )
227
+ global_config = _map_inc_config_to_torch_quant_config (config , is_dynamic )
217
228
quantizer .set_global (global_config )
218
229
# set local
219
230
for module_or_func_name , local_config in config .local_config .items ():
220
- local_quant_config = _map_inc_config_to_torch_quant_config (local_config )
231
+ local_quant_config = _map_inc_config_to_torch_quant_config (local_config , is_dynamic )
221
232
if isinstance (module_or_func_name , torch .nn .Module ):
222
233
quantizer .set_module_type_qconfig (module_or_func_name , local_quant_config )
223
234
else :
0 commit comments