@@ -21,34 +21,33 @@
 
 import torch
 
-from neural_compressor.torch.utils import logger
-from neural_compressor.torch.utils.utility import set_module
+from neural_compressor.torch.utils import logger, set_module
 
 from .utility import quant_tensor, search_clip
 
 
 @torch.no_grad()
 def rtn_quantize(
     model,
-    num_bits=4,
+    dtype="int",
+    bits=4,
+    scheme="sym",
     group_size=32,
-    scheme="asym",
+    group_dim=1,
     quantile=1.0,
     weight_config={},
-    return_int=False,
-    dtype="int",
-    enable_full_range=False,
-    enable_mse_search=False,
-    group_dim=1,
+    export_compressed_model=False,
+    use_full_range=False,
+    use_mse_search=False,
     **kwargs,
 ):
45- """Quant the model with round to nearest method.
44+ """Quant the model with round to nearest method and inplace is True .
 
     Args:
         model: torch module
-        num_bits: num bits. Defaults to 4.
+        bits: number of bits. Defaults to 4.
         group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
-        scheme (str, optional): sym or asym. Defaults to "asym".
+        scheme (str, optional): sym or asym. Defaults to "sym".
         quantile (float, optional): percentile of clip. Defaults to 1.0.
         dtype (str, optional): select from int, nf4, fp4. Defaults to int.
         weight_config (dict, optional): specific layer-wise configurations. Defaults to {}.
@@ -60,88 +59,98 @@ def rtn_quantize(
                             'bits': 4,
                             'group_size': 32,
                             'scheme': 'sym'
-                            'gptq_perm': [1, 1, ...]  # for gptq perm
                         }
                 }
-        return_int (bool, optional): Choose return fp32 or int32 model.
+        export_compressed_model (bool, optional): Whether to return the compressed (int) model.
             Defaults to False.
-        enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
+        use_full_range (bool, optional): Whether the sym range uses -2**(bits-1) as its lower bound.
             Defaults to False.
-        enable_mse_search (bool, optional): Whether search clip range.
+        use_mse_search (bool, optional): Whether to search for the best clip range.
             Defaults to False.
         group_dim (int, optional): 0 means splitting output channel,
             1 means splitting input channel. Defaults to 1.
 
     Returns:
         model: fake quantized torch module
     """
+    device = "cpu"
     assert isinstance(model, torch.nn.Module), "only support torch module"
     supported_layers = ["Linear"]
-    double_quant_dtype = kwargs.get("double_quant_dtype", "fp32")
+    # initialize global configuration
     double_quant_config = {
-        "double_quant": False if double_quant_dtype == "fp32" else True,
-        "double_quant_dtype": double_quant_dtype,
-        "double_quant_num_bits": kwargs.get("double_quant_num_bits", 8),
+        "double_quant": kwargs.get("use_double_quant", False),
+        "double_quant_dtype": kwargs.get("double_quant_dtype", "int"),
+        "double_quant_bits": kwargs.get("double_quant_bits", 8),
         "double_quant_scheme": kwargs.get("double_quant_scheme", "sym"),
         "double_quant_group_size": kwargs.get("double_quant_group_size", 256),
     }
-    if return_int:
-        compression_dtype = kwargs.get("compression_dtype", torch.int32)
-        compression_dim = kwargs.get("compression_dim", 1)
-        scale_dtype = kwargs.get("scale_dtype", torch.float32)
-        device = kwargs.get("device", "cpu")
+    if export_compressed_model:
+        use_optimum_format = kwargs.get("use_optimum_format", True)
     for name, m in model.named_modules():
         if m.__class__.__name__ not in supported_layers:
             continue
         if name in weight_config:  # pragma: no cover
+            # initialize op configuration
            dtype = weight_config[name].get("dtype", "int")
-            num_bits = weight_config[name]["bits"]
+            bits = weight_config[name].get("bits", 4)
             group_size = weight_config[name]["group_size"]
             scheme = weight_config[name]["scheme"]
             quantile = weight_config[name].get("quantile", 1.0)
+            group_dim = weight_config[name]["group_dim"]
+            use_full_range = weight_config[name]["use_full_range"]
+            use_mse_search = weight_config[name]["use_mse_search"]
+            use_layer_wise = weight_config[name]["use_layer_wise"]
+            export_compressed_model = weight_config[name]["export_compressed_model"]
+            if export_compressed_model:
+                use_optimum_format = kwargs.get("use_optimum_format", True)
+            # double quant config
+            double_quant_config = {
+                "double_quant": weight_config[name]["use_double_quant"],
+                "double_quant_dtype": weight_config[name]["double_quant_dtype"],
+                "double_quant_bits": weight_config[name]["double_quant_bits"],
+                "double_quant_scheme": weight_config[name]["double_quant_scheme"],
+                "double_quant_group_size": weight_config[name]["double_quant_group_size"],
+            }
         log_msg = (
-            f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
-            + f"scheme={scheme}, quantile={quantile}"
+            f"RTN quantization config: bits={bits}, group_size={group_size}, " + f"scheme={scheme}, quantile={quantile}"
         )
         if dtype != "int":
             log_msg += f", dtype={dtype}"
         elif scheme == "sym":  # nf4/fp4 is always [-7,7]
-            log_msg += f", enable_full_range={enable_full_range}"
+            log_msg += f", use_full_range={use_full_range}"
         if dtype == "fp32":
             continue
         logger.debug(f"RTN quantized module:{name, m}")
         logger.debug(log_msg)
-        weight = m.weight.T if group_dim == 0 else m.weight
-        if enable_mse_search:
-            quantile = search_clip(m, num_bits, group_size, scheme, dtype, enable_full_range)
-        if return_int:
+        weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight
+        if use_mse_search:
+            quantile = search_clip(m, bits, group_size, scheme, dtype, use_full_range)
+        if export_compressed_model:
             int_weight, scale, zp = quant_tensor(
                 weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
                 dtype=dtype,
+                bits=bits,
+                group_size=group_size,
+                scheme=scheme,
+                quantile=quantile,
                 return_int=True,
-                full_range=enable_full_range,
+                full_range=use_full_range,
                 **double_quant_config,
             )
-            int_weight = int_weight.T if group_dim == 0 else int_weight
-            scale = scale.T if group_dim == 0 else scale
-            zp = zp.T if group_dim == 0 and zp is not None else zp
+            int_weight = int_weight.t_().contiguous() if group_dim == 0 else int_weight
+            scale = scale.t_().contiguous() if group_dim == 0 else scale
+            zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
             from neural_compressor.torch.quantization.layers import WeightOnlyLinear
 
             new_module = WeightOnlyLinear(
                 m.in_features,
                 m.out_features,
-                num_bits,
-                group_size,
+                bits=bits,
+                group_size=group_size,
                 dtype=dtype,
                 zp=zp is not None,
                 bias=m.bias is not None,
-                compression_dtype=compression_dtype,
-                compression_dim=compression_dim,
-                scale_dtype=scale_dtype,
+                use_optimum_format=use_optimum_format,
                 device=device,
             )
             new_module.pack(int_weight, scale, zp, m.bias)
@@ -150,50 +159,16 @@ def rtn_quantize(
             else:
                 set_module(model, name, new_module)
         else:
-            q_weight = quant_tensor(
+            weight = quant_tensor(
                 weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
                 dtype=dtype,
-                full_range=enable_full_range,
+                bits=bits,
+                group_size=group_size,
+                scheme=scheme,
+                quantile=quantile,
+                full_range=use_full_range,
                 **double_quant_config,
             )
-            q_weight = q_weight.T if group_dim == 0 else q_weight
-            m.weight.data.copy_(q_weight)
+            weight = weight.t_().contiguous() if group_dim == 0 else weight
+            m.weight.data.copy_(weight)
     return model
-
-
-from neural_compressor.torch.quantization.config import RTNConfig
-
-
-def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
-    # TODO (Yi) remove it
-    enable_full_range = quant_config.enable_full_range
-    enable_mse_search = quant_config.enable_mse_search
-    group_dim = quant_config.group_dim
-    dtype = quant_config.weight_dtype
-    num_bits = quant_config.weight_bits
-    scheme = "sym" if quant_config.weight_sym else "asym"
-    group_size = quant_config.weight_group_size
-    return_int = quant_config.return_int
-    double_quant_dtype = quant_config.double_quant_dtype
-    double_quant_num_bits = quant_config.double_quant_bits
-    double_quant_scheme = "sym" if quant_config.double_quant_sym else "asym"
-    double_quant_group_size = quant_config.double_quant_group_size
-    return rtn_quantize(
-        module,
-        num_bits,
-        group_size,
-        scheme,
-        return_int=return_int,
-        dtype=dtype,
-        enable_full_range=enable_full_range,
-        enable_mse_search=enable_mse_search,
-        group_dim=group_dim,
-        double_quant_dtype=double_quant_dtype,
-        double_quant_scheme=double_quant_scheme,
-        double_quant_num_bits=double_quant_num_bits,
-        double_quant_group_size=double_quant_group_size,
-    )
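
For reviewers who want to try the renamed interface, here is a minimal usage sketch (not part of the patch). The import path and `ToyModel` are illustrative assumptions; the `weight_config` keys mirror exactly what the patched loop reads, and since most per-layer keys are indexed without a `.get()` fallback, an entry has to spell them all out.

```python
import torch

# NOTE: assumed import path; the diff does not show where this module lives.
from neural_compressor.torch.algorithms.weight_only.rtn import rtn_quantize


class ToyModel(torch.nn.Module):
    """Hypothetical two-layer model used only to exercise the API."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(64, 64)
        self.fc2 = torch.nn.Linear(64, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x))


# Global settings: 4-bit symmetric int weights, 32 elements per scale/zp group.
model = rtn_quantize(ToyModel(), dtype="int", bits=4, scheme="sym", group_size=32)

# Per-layer override for fc2. The patched loop indexes most of these keys
# directly (no default), so a per-layer entry must provide every one of them.
weight_config = {
    "fc2": {
        "dtype": "int",
        "bits": 8,
        "group_size": 32,
        "scheme": "sym",
        "quantile": 1.0,
        "group_dim": 1,
        "use_full_range": False,
        "use_mse_search": False,
        "use_layer_wise": False,
        "export_compressed_model": False,
        "use_double_quant": False,
        "double_quant_dtype": "int",
        "double_quant_bits": 8,
        "double_quant_scheme": "sym",
        "double_quant_group_size": 256,
    }
}
model = rtn_quantize(ToyModel(), weight_config=weight_config)
```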