@@ -445,7 +445,7 @@ output_text = tokenizer.batch_decode(
print(output_text)
```

#### 2. Quantizing different layers with different quantization configs (no regex)

```py
import torch
@@ -484,6 +484,119 @@ output_text = tokenizer.batch_decode(
print(output_text)
```

#### 3. Quantizing different layers with different quantization configs (with regex)
We can also use a regex to specify the config for all modules whose `module_fqn` matches the pattern. Every regex key must start with `re:`; for example, `re:layers\..*\.gate_proj` will match modules like `layers.0.gate_proj`. Exact `module_fqn` keys take precedence over regex keys, and `_default` applies to any module that matches neither. See [here](https://github.com/pytorch/ao/blob/2fe0ca0899c730c528efdbec8886feaa38879f39/torchao/quantization/quant_api.py#L2392) for the docs.
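Note that a regex key only applies when it matches the whole `module_fqn` (the `out_pro` entry in the example below relies on this). As a quick standalone illustration of that matching behavior, using plain Python `re` rather than any torchao API:

```py
import re

fqn = "model.decoder.layers.0.self_attn.out_proj"

# A truncated pattern (missing the trailing "j") does not fully match,
# so a module with this fqn would fall back to the `_default` config.
print(re.fullmatch(r"model\.decoder\.layers\..+\.self_attn\.out_pro", fqn))   # None

# The complete pattern fully matches, so its config would apply.
print(re.fullmatch(r"model\.decoder\.layers\..+\.self_attn\.out_proj", fqn))  # <re.Match ...>
```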

```py
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# Configure logging to see warnings and debug information
logging.basicConfig(
    level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s"
)

# Enable specific loggers that might contain the serialization warnings
logging.getLogger("transformers").setLevel(logging.INFO)
logging.getLogger("torchao").setLevel(logging.INFO)
logging.getLogger("safetensors").setLevel(logging.INFO)
logging.getLogger("huggingface_hub").setLevel(logging.INFO)

model_id = "facebook/opt-125m"

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Int4WeightOnlyConfig,
    IntxWeightOnlyConfig,
    PerRow,
    PerAxis,
    ModuleFqnToConfig,
    Float8Tensor,
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
)

float8dyn = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
int4wo = Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d")
intxwo = IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0))

qconfig_dict = {
    # exact fqn keys have the highest priority
    "model.decoder.layers.3.self_attn.q_proj": int4wo,
    "model.decoder.layers.3.self_attn.k_proj": int4wo,
    "model.decoder.layers.3.self_attn.v_proj": int4wo,
    # fused qkv projection name used by vLLM
    "model.decoder.layers.3.self_attn.qkv_proj": int4wo,

    "re:model\.decoder\.layers\..+\.self_attn\.q_proj": float8dyn,
    "re:model\.decoder\.layers\..+\.self_attn\.k_proj": float8dyn,
    "re:model\.decoder\.layers\..+\.self_attn\.v_proj": float8dyn,
    # this entry should not take effect and we'll fall back to `_default`,
    # since there is no full match (the trailing `j` is missing)
    "re:model\.decoder\.layers\..+\.self_attn\.out_pro": float8dyn,
    # fused qkv projection name used by vLLM
    "re:model\.decoder\.layers\..+\.self_attn\.qkv_proj": float8dyn,

    "_default": intxwo,
}
quant_config = ModuleFqnToConfig(qconfig_dict)
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# layer 3 should be int4, every other layer float8, and out_proj the intx `_default`
for i in range(12):
    if i == 3:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Int4TilePackedTo4dTensor)
    else:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Float8Tensor)
    assert isinstance(quantized_model.model.decoder.layers[i].self_attn.out_proj.weight, IntxUnpackedToInt8Tensor)

# Manual testing
prompt = "What are we having for dinner?"
print("Prompt:", prompt)
inputs = tokenizer(
    prompt,
    return_tensors="pt",
).to("cuda")
# set temperature to 0 to make the result deterministic
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, temperature=0)

correct_output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", correct_output_text[0][len(prompt):])


# Save the quantized model and load it back from the checkpoint
# (`save_to` can be any local directory; the name below is just an example)
save_to = "opt-125m-torchao-mixed"
# torchao quantized weights are tensor subclasses, so save them without safetensors
quantized_model.save_pretrained(save_to, safe_serialization=False)

reloaded_model = AutoModelForCausalLM.from_pretrained(
    save_to,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    # the quantization config is restored from the checkpoint, so it does not need to be passed again
    # quantization_config=quantization_config,
)

generated_ids = reloaded_model.generate(**inputs, max_new_tokens=128, temperature=0)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", output_text[0][len(prompt):])

assert correct_output_text == output_text
```

### Autoquant

If you want to automatically choose a quantization type for quantizable layers (`nn.Linear`), you can use the [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API.
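As a rough sketch of what this can look like with the torchao API directly (the model id is just carried over from the earlier examples, and the exact compile/benchmarking flow may differ across torchao versions):

```py
import torch
import torchao
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# autoquant wraps the quantizable nn.Linear layers; for the input shapes it sees,
# it benchmarks a set of quantization options per layer and keeps the fastest one
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))

inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")
# running the model once lets autoquant benchmark the options for the observed shapes
generated_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```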