
Commit 8111e5b

add assert for
1 parent d302766 commit 8111e5b

File tree

2 files changed, +117 -3 lines changed


docs/source/en/quantization/torchao.md

Lines changed: 114 additions & 1 deletion
@@ -445,7 +445,7 @@ output_text = tokenizer.batch_decode(
 print(output_text)
 ```

-#### 2. Quantizing different layers with different quantization configs
+#### 2. Quantizing different layers with different quantization configs (no regex)

 ```py
 import torch
@@ -484,6 +484,119 @@ output_text = tokenizer.batch_decode(
 print(output_text)
 ```

+#### 3. Quantizing different layers with different quantization configs (with regex)
+We can also use a regex to specify the config for all modules whose `module_fqn`
+matches the regex. Every regex key must start with `re:`; for example, `re:layers\..*\.gate_proj` will
+match all layers like `layers.0.gate_proj`. See [here](https://github.com/pytorch/ao/blob/2fe0ca0899c730c528efdbec8886feaa38879f39/torchao/quantization/quant_api.py#L2392) for the docs.
+
+```py
+import logging
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
+
+# Configure logging to see warnings and debug information
+logging.basicConfig(
+    level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s"
+)
+
+# Enable specific loggers that might contain the serialization warnings
+logging.getLogger("transformers").setLevel(logging.INFO)
+logging.getLogger("torchao").setLevel(logging.INFO)
+logging.getLogger("safetensors").setLevel(logging.INFO)
+logging.getLogger("huggingface_hub").setLevel(logging.INFO)
+
+model_id = "facebook/opt-125m"
+
+from torchao.quantization import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Int4WeightOnlyConfig,
+    IntxWeightOnlyConfig,
+    PerRow,
+    PerAxis,
+    ModuleFqnToConfig,
+    Float8Tensor,
+    Int4TilePackedTo4dTensor,
+    IntxUnpackedToInt8Tensor,
+)
+
+float8dyn = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+int4wo = Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d")
+intxwo = IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0))
+
+qconfig_dict = {
+    # highest priority
+    "model.decoder.layers.3.self_attn.q_proj": int4wo,
+    "model.decoder.layers.3.self_attn.k_proj": int4wo,
+    "model.decoder.layers.3.self_attn.v_proj": int4wo,
+    # vllm
+    "model.decoder.layers.3.self_attn.qkv_proj": int4wo,
+
+    "re:model\.decoder\.layers\..+\.self_attn\.q_proj": float8dyn,
+    "re:model\.decoder\.layers\..+\.self_attn\.k_proj": float8dyn,
+    "re:model\.decoder\.layers\..+\.self_attn\.v_proj": float8dyn,
+    # this should not take effect and we'll fall back to _default
+    # since there is no full match (the trailing `j` is missing)
+    "re:model\.decoder\.layers\..+\.self_attn\.out_pro": float8dyn,
+    # vllm
+    "re:model\.decoder\.layers\..+\.self_attn\.qkv_proj": float8dyn,
+
+    "_default": intxwo,
+}
+quant_config = ModuleFqnToConfig(qconfig_dict)
+quantization_config = TorchAoConfig(quant_type=quant_config)
+quantized_model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+    quantization_config=quantization_config,
+)
+print("quantized model:", quantized_model)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+for i in range(12):
+    if i == 3:
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Int4TilePackedTo4dTensor)
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Int4TilePackedTo4dTensor)
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Int4TilePackedTo4dTensor)
+    else:
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Float8Tensor)
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Float8Tensor)
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Float8Tensor)
+    # out_proj matches no exact fqn and no regex, so it falls back to _default
+    assert isinstance(quantized_model.model.decoder.layers[i].self_attn.out_proj.weight, IntxUnpackedToInt8Tensor)
+
+# Manual Testing
+prompt = "What are we having for dinner?"
+print("Prompt:", prompt)
+inputs = tokenizer(
+    prompt,
+    return_tensors="pt",
+).to("cuda")
+# set temperature to 0 to make the result deterministic
+generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, temperature=0)
+
+correct_output_text = tokenizer.batch_decode(
+    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print("Response:", correct_output_text[0][len(prompt) :])
+
+# Save the quantized model so it can be reloaded below
+save_to = "./opt-125m-torchao-mixed"  # example path; any writable directory works
+quantized_model.save_pretrained(save_to, safe_serialization=False)
+
+# Load model from saved checkpoint
+reloaded_model = AutoModelForCausalLM.from_pretrained(
+    save_to,
+    device_map="cuda:0",
+    torch_dtype=torch.bfloat16,
+    # quantization_config=quantization_config,
+)
+
+generated_ids = reloaded_model.generate(**inputs, max_new_tokens=128, temperature=0)
+output_text = tokenizer.batch_decode(
+    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print("Response:", output_text[0][len(prompt) :])
+
+assert correct_output_text == output_text
+```
+
 ### Autoquant

 If you want to automatically choose a quantization type for quantizable layers (`nn.Linear`) you can use the [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API.

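For context on the Autoquant section this hunk ends in, usage looks roughly like the following. This is a minimal sketch using the torchao-level `torchao.autoquant` API rather than the exact example from the docs page; the model id mirrors the snippets above, and in practice autoquant is usually combined with `torch.compile` for best performance.

```py
import torch
import torchao
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"  # same example model as above
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# autoquant wraps each nn.Linear weight, benchmarks candidate quantization
# types on the inputs it sees, and keeps the fastest choice per layer
model = torchao.autoquant(model)

inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to(model.device)
# the first call exercises the benchmarking and locks in the per-layer choices
generated_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```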
src/transformers/quantizers/quantizer_torchao.py

Lines changed: 3 additions & 2 deletions
@@ -297,15 +297,16 @@ def create_quantized_param(

         # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
         if self.quantization_config._get_ao_version() >= version.Version("0.12.0"):
-            import re
-
             from torchao.quantization import ModuleFqnToConfig

             config = self.quantization_config.get_apply_tensor_subclass()
             if isinstance(config, ModuleFqnToConfig):
                 module_fqn, _ = param_name.rsplit(".", 1)
                 c = None
                 if module_fqn in config.module_fqn_to_config:
+                    assert not module_fqn.startswith("re:"), (
+                        "module fqn should not start with `re:`, which is used for specifying regex"
+                    )
                     c = config.module_fqn_to_config[module_fqn]
                 else:
                     for maybe_module_fqn_pattern in config.module_fqn_to_config:

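To illustrate the lookup order this hunk relies on (and that the docs example above exercises), here is a standalone sketch of how a module fqn resolves to a config: an exact fqn entry wins, then `re:` patterns matched against the full fqn, then `_default`. This is an illustration, not the actual transformers/torchao code; `module_fqn_to_config` is a plain dict here.

```py
import re

def resolve_config(module_fqn: str, module_fqn_to_config: dict):
    # 1. an exact module fqn entry has the highest priority
    if module_fqn in module_fqn_to_config:
        return module_fqn_to_config[module_fqn]
    # 2. otherwise try every `re:`-prefixed key; the pattern must match the full fqn,
    #    which is why `...out_pro` (missing the trailing `j`) falls through in the docs example
    for key, config in module_fqn_to_config.items():
        if key.startswith("re:") and re.fullmatch(key[len("re:"):], module_fqn):
            return config
    # 3. fall back to the catch-all `_default` entry if present
    return module_fqn_to_config.get("_default")

# e.g. with the qconfig_dict from the docs example above:
# resolve_config("model.decoder.layers.3.self_attn.q_proj", qconfig_dict)   -> int4wo
# resolve_config("model.decoder.layers.5.self_attn.q_proj", qconfig_dict)   -> float8dyn
# resolve_config("model.decoder.layers.5.self_attn.out_proj", qconfig_dict) -> intxwo
```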