2424import  torch 
2525
2626from  neural_compressor .common .base_config  import  BaseConfig , config_registry , register_config 
27- from  neural_compressor .common .utility  import  (
28-     DEFAULT_WHITE_LIST ,
29-     FP8_QUANT ,
30-     GPTQ ,
31-     OP_NAME_OR_MODULE_TYPE ,
32-     RTN_WEIGHT_ONLY_QUANT ,
33- )
27+ from  neural_compressor .common .utility  import  DEFAULT_WHITE_LIST , FP8_QUANT , GPTQ , OP_NAME_OR_MODULE_TYPE , RTN 
3428from  neural_compressor .torch .utils .constants  import  PRIORITY_GPTQ , PRIORITY_RTN 
3529from  neural_compressor .torch .utils .utility  import  is_hpex_avaliable , logger 
3630
@@ -60,8 +54,8 @@ class OperatorConfig(NamedTuple):
6054######################## RTN Config ############################### 
6155
6256
63- @register_config (framework_name = FRAMEWORK_NAME , algo_name = RTN_WEIGHT_ONLY_QUANT , priority = PRIORITY_RTN ) 
64- class  RTNWeightQuantConfig (BaseConfig ):
57+ @register_config (framework_name = FRAMEWORK_NAME , algo_name = RTN , priority = PRIORITY_RTN ) 
58+ class  RTNConfig (BaseConfig ):
6559    """Config class for round-to-nearest weight-only quantization.""" 
6660
6761    supported_configs : List [OperatorConfig ] =  []
@@ -80,7 +74,7 @@ class RTNWeightQuantConfig(BaseConfig):
8074        "double_quant_sym" ,
8175        "double_quant_group_size" ,
8276    ]
83-     name  =  RTN_WEIGHT_ONLY_QUANT 
77+     name  =  RTN 
8478
8579    def  __init__ (
8680        self ,
@@ -137,12 +131,12 @@ def to_dict(self):
137131
138132    @classmethod  
139133    def  from_dict (cls , config_dict ):
140-         return  super (RTNWeightQuantConfig , cls ).from_dict (config_dict = config_dict , str2operator = str2operator )
134+         return  super (RTNConfig , cls ).from_dict (config_dict = config_dict , str2operator = str2operator )
141135
142136    @classmethod  
143137    def  register_supported_configs (cls ) ->  List [OperatorConfig ]:
144138        supported_configs  =  []
145-         linear_rtn_config  =  RTNWeightQuantConfig (
139+         linear_rtn_config  =  RTNConfig (
146140            weight_dtype = ["int" , "int8" , "int4" , "nf4" , "fp4" , "fp4_e2m1_bnb" , "fp4_e2m1" ],
147141            weight_bits = [4 , 1 , 2 , 3 , 5 , 6 , 7 , 8 ],
148142            weight_group_size = [32 , - 1 , 1 , 4 , 8 , 16 , 64 , 128 , 256 , 512 , 1024 ],
@@ -173,16 +167,16 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
173167
174168
175169# TODO(Yi) run `register_supported_configs` for all registered configs. 
176- RTNWeightQuantConfig .register_supported_configs ()
170+ RTNConfig .register_supported_configs ()
177171
178172
179- def  get_default_rtn_config () ->  RTNWeightQuantConfig :
173+ def  get_default_rtn_config () ->  RTNConfig :
180174    """Generate the default rtn config. 
181175
182176    Returns: 
183177        the default rtn config. 
184178    """ 
185-     return  RTNWeightQuantConfig ()
179+     return  RTNConfig ()
186180
187181
188182######################## GPTQ Config ############################### 
0 commit comments