 # See the License for the specific language governing permissions and
 # limitations under the License.

+import torch
+import torchao
+from packaging.version import parse
 from transformers import AutoModelForCausalLM, GenerationConfig

 from ..integrations import CausalLMExportableModule
@@ -54,12 +57,14 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportableModule:
     cache_implementation = kwargs.get("cache_implementation", "static")
     max_length = kwargs.get("max_length", 2048)
     config = kwargs.get("config", None)
+    quantization_recipe = kwargs.get("quantize", None)

     eager_model = AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
         device_map=device,
         torch_dtype=dtype,
         config=config,
+        # quantization_config=quantization_config,
         attn_implementation=attn_implementation,
         generation_config=GenerationConfig(
             use_cache=True,
@@ -71,4 +76,25 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportableModule:
             },
         ),
     )
+
+    if quantization_recipe == "8da4w":
+        if parse(torchao.__version__) < parse("0.11.0.dev0"):
+            raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.")
+
+        from torchao.quantization.granularity import PerGroup
+        from torchao.quantization.quant_api import (
+            Int8DynamicActivationIntxWeightConfig,
+        )
+
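+        # "8da4w" = 8-bit dynamically quantized activations with 4-bit
+        # grouped weights (group size 128), applied to the linear layers.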
+        # TODO: Switch to TorchAoConfig once the quantization issue on the final lm_head layer is fixed.
+        linear_config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=torch.int4,
+            weight_granularity=PerGroup(128),
+        )
+
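+        # quantize_ rewrites the model's linear layers in place with
+        # quantized implementations; no reassignment is needed.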
+        torchao.quantize_(
+            eager_model,
+            linear_config,
+        )
+
     return CausalLMExportableModule(eager_model)
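
For reference, a minimal sketch of how the new "quantize" kwarg could be exercised once this lands (the checkpoint id and import path are hypothetical examples, not taken from this diff):

    from optimum.exporters.executorch.tasks.causal_lm import load_causal_lm_model  # assumed module path

    exportable = load_causal_lm_model(
        "HuggingFaceTB/SmolLM2-135M",  # hypothetical example checkpoint
        quantize="8da4w",              # int8 dynamic activations, int4 grouped weights
        max_length=2048,
    )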