@@ -230,8 +230,8 @@ def get_user_model():
 
     # 3.x api
     if args.approach == 'weight_only':
-        from neural_compressor.torch import RTNWeightQuantConfig, GPTQConfig, quantize
-        from neural_compressor.torch.utils.utility import get_double_quant_config
+        from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, quantize
+        from neural_compressor.torch.utils import get_double_quant_config
         weight_sym = True if args.woq_scheme == "sym" else False
         double_quant_config_dict = get_double_quant_config(args.double_quant_type, weight_sym=weight_sym)
 
@@ -243,9 +243,9 @@ def get_user_model():
                         "enable_mse_search": args.woq_enable_mse_search,
                     }
                 )
-            quant_config = RTNWeightQuantConfig.from_dict(double_quant_config_dict)
+            quant_config = RTNConfig.from_dict(double_quant_config_dict)
         else:
-            quant_config = RTNWeightQuantConfig(
+            quant_config = RTNConfig(
                 weight_dtype=args.woq_dtype,
                 weight_bits=args.woq_bits,
                 weight_group_size=args.woq_group_size,
@@ -257,7 +257,7 @@ def get_user_model():
                 double_quant_sym=args.double_quant_sym,
                 double_quant_group_size=args.double_quant_group_size,
             )
-        quant_config.set_local("lm_head", RTNWeightQuantConfig(weight_dtype="fp32"))
+        quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))
         user_model = quantize(
             model=user_model, quant_config=quant_config
         )
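
For reference, a minimal sketch of the renamed 3.x entry points wired together end to end. The model checkpoint and the argument values below are illustrative placeholders, not taken from the patch; the import path, the RTNConfig argument names, set_local, and the quantize call follow the lines changed above.

    # Sketch of the updated 3.x weight-only RTN flow (illustrative values).
    from transformers import AutoModelForCausalLM

    from neural_compressor.torch.quantization import RTNConfig, quantize

    # Placeholder checkpoint; the example script builds user_model from args.
    user_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

    # 4-bit grouped RTN; argument names match the patched constructor call.
    quant_config = RTNConfig(
        weight_dtype="int",
        weight_bits=4,
        weight_group_size=128,
        weight_sym=False,
    )
    # Keep lm_head in fp32, as the patch does, to protect output quality.
    quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))

    user_model = quantize(model=user_model, quant_config=quant_config)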