# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
from typing import Any, Optional

import torch
@@ -40,7 +41,8 @@ class TorchAOConfig(QuantizationConfig):

    def __init__(self,
                 torchao_config,
-                 skip_modules: Optional[list[str]] = None) -> None:
+                 skip_modules: Optional[list[str]] = None,
+                 is_checkpoint_torchao_serialized: bool = False) -> None:
        """
        # TorchAO quantization relies on tensor subclasses. In order
        # to enable proper caching, this needs standalone compile
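For reference, a minimal sketch of constructing the updated config directly. This is hedged: `Float8DynamicActivationFloat8WeightConfig` and `PerRow` are the torchao names used by the docstrings later in this diff, and the `skip_modules` value is illustrative.

```python
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  PerRow)

# Illustrative construction; the skip_modules entry is hypothetical.
cfg = TorchAOConfig(
    Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
    skip_modules=["lm_head"],
    # False: the checkpoint holds plain weights, quantized on the fly after
    # loading; True: the checkpoint already holds torchao-serialized tensors.
    is_checkpoint_torchao_serialized=False,
)
```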
@@ -58,9 +60,11 @@ def __init__(self,
        super().__init__()
        self.torchao_config = torchao_config
        self.skip_modules = skip_modules or []
+        self.is_checkpoint_torchao_serialized = is_checkpoint_torchao_serialized

    def __repr__(self) -> str:
-        return f"TorchAOConfig({self.torchao_config})"
+        return f"TorchAOConfig({self.torchao_config=}, {self.skip_modules=}, " \
+               f"{self.is_checkpoint_torchao_serialized=})"

    def get_name(self) -> QuantizationMethods:
        return "torchao"
@@ -74,7 +78,10 @@ def get_min_capability(cls) -> int:

    @staticmethod
    def get_config_filenames() -> list[str]:
-        return ["config.json"]
81+ """torchao doesn't require additional config files, we use
82+ `config.json` from huggingface: `model_config.hf_config`
83+ """
84+ return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
@@ -87,6 +94,10 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
            "`pip install torchao>=0.10.0` to use torchao quantization."
        ) from err
+        quant_method = cls.get_from_keys_or(config, ["quant_method"], None)
+        is_checkpoint_torchao_serialized = (quant_method is not None
+                                            and "torchao" in quant_method)
+
        hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
        assert hf_config is not None, "quant_type must be specified"
        assert len(hf_config) == 1 and "default" in hf_config, (
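For context, a hedged sketch of the `quantization_config` dict this method consumes. The `config_to_dict` import path is an assumption (the docstrings in this diff use the bare name), and the field values are illustrative:

```python
from torchao.core.config import config_to_dict  # assumed import path
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  PerRow)

quant_cfg = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
hf_quant_config = {
    # "torchao" here marks the checkpoint as torchao-serialized
    "quant_method": "torchao",
    # "quant_type" must contain exactly one entry, keyed "default"
    "quant_type": {"default": config_to_dict(quant_cfg)},
}
torchao_config = TorchAOConfig.from_config(hf_quant_config)
```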
@@ -110,7 +121,38 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
            if layer_cfg is None:
                skip_modules.append(layer)

-        return cls(ao_config, skip_modules)
+        return cls(ao_config, skip_modules, is_checkpoint_torchao_serialized)
+
+    @classmethod
+    def from_config_file(cls, config_file: str) -> "TorchAOConfig":
+        """Initialize class from a config file. Example:
+        ```
+        config = (
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+        fn = "torchao_config.json"
+
+        with open(fn, "w") as f:
+            f.write(json.dumps(config_to_dict(config)))
+        ```
+        """
+        with open(config_file) as f:
+            f.seek(0)
+            f_read = f.read()
+            config_dict = json.loads(f_read)
+
+        hf_config = {"quant_type": {"default": config_dict}}
+        return cls.from_config(hf_config)
+
+    @classmethod
+    def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig":
149+ """Iniitalize class from a config_dict json string, got from
150+ torchao_config_object = some AOBaseConfig object
151+ json.dumps(config_to_dict(torchao_config_object))
152+ """
+        config_dict = json.loads(config_dict_json)
+        hf_config = {"quant_type": {"default": config_dict}}
+        return cls.from_config(hf_config)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
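A hedged usage sketch for the two new constructors above (file name illustrative; `config_to_dict` import path assumed as before). Note that both paths build an `hf_config` without a `quant_method` key, so `is_checkpoint_torchao_serialized` stays False and the weights will be quantized on the fly:

```python
import json

from torchao.core.config import config_to_dict  # assumed import path
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  PerRow)

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
config_json = json.dumps(config_to_dict(config))

# Round-trip through a file, mirroring the from_config_file docstring.
with open("torchao_config.json", "w") as f:
    f.write(config_json)
cfg_a = TorchAOConfig.from_config_file("torchao_config.json")

# Or construct directly from the JSON string.
cfg_b = TorchAOConfig.from_config_dict_json(config_json)
```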
@@ -128,7 +170,9 @@ def get_quant_method(self, layer: torch.nn.Module,
        c = module_fqn_to_config.get(
            module_fqn) or module_fqn_to_config.get("_default", None)
        if c is not None:
-            current_torchao_config = TorchAOConfig(c, self.skip_modules)
+            current_torchao_config = TorchAOConfig(
+                c, self.skip_modules,
+                self.is_checkpoint_torchao_serialized)
            return TorchAOLinearMethod(current_torchao_config)
        else:
            return UnquantizedLinearMethod()
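The per-module dispatch above reads torchao's module-FQN-to-config mapping; a hedged sketch of such a config follows (`ModuleFqnToConfig` is assumed to be exported from `torchao.quantization`, and the FQN is illustrative):

```python
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  ModuleFqnToConfig, PerRow)

# "_default" applies to every module without an explicit entry; a None
# entry opts a module out (from_config above collects it in skip_modules,
# and get_quant_method falls back to UnquantizedLinearMethod).
per_module_config = ModuleFqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow()),
    "model.layers.0.self_attn.q_proj": None,  # hypothetical FQN
})
```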
@@ -172,7 +216,7 @@ class TorchAOLinearMethod(LinearMethodBase):
    """Linear method for torchao.

    Args:
-        quant_config: The torchao quantization config, a string that encodes 
+        quant_config: The torchao quantization config, a string that encodes
            the type of quantization and all relevant arguments.
    """

@@ -197,8 +241,9 @@ def create_weights(
            ),
            requires_grad=False,
        )
-        weight = torchao_quantize_param_data(weight,
-                                             self.quant_config.torchao_config)
+        if self.quant_config.is_checkpoint_torchao_serialized:
+            weight = torchao_quantize_param_data(
+                weight, self.quant_config.torchao_config)

        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

@@ -212,3 +257,14 @@ def apply(
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.quant_config.is_checkpoint_torchao_serialized:
+            return
+
+        # quantize the weight on the fly if the checkpoint is not already
+        # quantized by torchao
+        weight = torchao_quantize_param_data(layer.weight,
+                                             self.quant_config.torchao_config)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        layer.register_parameter("weight", weight)
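The new `process_weights_after_loading` hook reaches broadly the same end state as torchao's usual quantize-after-load flow; a standalone sketch of that pattern using torchao's public `quantize_` API (toy sizes; a CUDA device with fp8 support is assumed for this particular config):

```python
import torch
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  PerRow, quantize_)

# Toy stand-in for a linear layer whose checkpoint was not
# torchao-serialized; fp8 dynamic quantization needs a capable GPU.
layer = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

# quantize_ swaps layer.weight for a torchao tensor subclass in place,
# analogous to what process_weights_after_loading does via
# torchao_quantize_param_data.
quantize_(layer,
          Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
out = layer(torch.randn(4, 128, dtype=torch.bfloat16, device="cuda"))
```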