115 changes: 114 additions & 1 deletion docs/source/en/quantization/torchao.md
@@ -445,7 +445,7 @@ output_text = tokenizer.batch_decode(
print(output_text)
```

#### 2. Quantizing different layers with different quantization configs
#### 2. Quantizing different layers with different quantization configs (no regex)

```py
import torch
@@ -484,6 +484,119 @@ output_text = tokenizer.batch_decode(
print(output_text)
```

#### 3. Quantizing different layers with different quantization configs (with regex)
We can also use a regex to specify a config for every module whose `module_fqn` matches the pattern. Regex keys must start with `re:`; for example, `re:layers\..*\.gate_proj` matches modules like `layers.0.gate_proj`. An exact module FQN takes precedence over a regex key, and the first fully matching regex takes precedence over `_default`. See [here](https://github.com/pytorch/ao/blob/2fe0ca0899c730c528efdbec8886feaa38879f39/torchao/quantization/quant_api.py#L2392) for the documentation.
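
Before the full end-to-end example, here is a minimal sketch of how exact FQNs, regex keys, and a `_default` entry can be mixed in one `ModuleFqnToConfig` (the module names are illustrative, not taken from a specific model):

```py
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    IntxWeightOnlyConfig,
    ModuleFqnToConfig,
    PerAxis,
    PerRow,
)

# Precedence: an exact module FQN wins over a regex key, and the first fully
# matching `re:` pattern wins over `_default`.
quant_config = ModuleFqnToConfig({
    # exact FQN: highest priority, overrides the regex below for this module
    "layers.0.gate_proj": IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
    # regex key: applies to every other module whose FQN fully matches the pattern
    r"re:layers\..*\.gate_proj": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
    # fallback for all remaining modules
    "_default": IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
})
```

The complete example below applies the same idea to `facebook/opt-125m`: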

```py
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# Configure logging to see warnings and debug information
logging.basicConfig(
    level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s"
)

# Enable specific loggers that might contain the serialization warnings
logging.getLogger("transformers").setLevel(logging.INFO)
logging.getLogger("torchao").setLevel(logging.INFO)
logging.getLogger("safetensors").setLevel(logging.INFO)
logging.getLogger("huggingface_hub").setLevel(logging.INFO)

model_id = "facebook/opt-125m"

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Int4WeightOnlyConfig,
    IntxWeightOnlyConfig,
    PerRow,
    PerAxis,
    ModuleFqnToConfig,
    Float8Tensor,
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
)

float8dyn = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
int4wo = Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d")
intxwo = IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0))

qconfig_dict = {
    # exact module FQNs: highest priority
    "model.decoder.layers.3.self_attn.q_proj": int4wo,
    "model.decoder.layers.3.self_attn.k_proj": int4wo,
    "model.decoder.layers.3.self_attn.v_proj": int4wo,
    # fused qkv_proj (as used in vLLM)
    "model.decoder.layers.3.self_attn.qkv_proj": int4wo,

    r"re:model\.decoder\.layers\..+\.self_attn\.q_proj": float8dyn,
    r"re:model\.decoder\.layers\..+\.self_attn\.k_proj": float8dyn,
    r"re:model\.decoder\.layers\..+\.self_attn\.v_proj": float8dyn,
    # this entry should not take effect and we'll fall back to _default,
    # since it is not a full match (the trailing `j` is missing)
    r"re:model\.decoder\.layers\..+\.self_attn\.out_pro": float8dyn,
    # fused qkv_proj (as used in vLLM)
    r"re:model\.decoder\.layers\..+\.self_attn\.qkv_proj": float8dyn,

    "_default": intxwo,
}
quant_config = ModuleFqnToConfig(qconfig_dict)
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
for i in range(12):
    if i == 3:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Int4TilePackedTo4dTensor)
    else:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Float8Tensor)
    assert isinstance(quantized_model.model.decoder.layers[i].self_attn.out_proj.weight, IntxUnpackedToInt8Tensor)

# Manual Testing
prompt = "What are we having for dinner?"
print("Prompt:", prompt)
inputs = tokenizer(
    prompt,
    return_tensors="pt",
).to("cuda")
# set temperature to 0 to make sure the result is deterministic
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, temperature=0)

correct_output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", correct_output_text[0][len(prompt) :])


# Save the quantized model so it can be reloaded below
# (the local path is illustrative; torchao checkpoints currently require safe_serialization=False)
save_to = "opt-125m-torchao-per-module"
quantized_model.save_pretrained(save_to, safe_serialization=False)
tokenizer.save_pretrained(save_to)

# Load the model from the saved checkpoint
reloaded_model = AutoModelForCausalLM.from_pretrained(
    save_to,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    # quantization_config is not needed here; it is read from the saved checkpoint
)

generated_ids = reloaded_model.generate(**inputs, max_new_tokens=128, temperature=0)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", output_text[0][len(prompt) :])

assert correct_output_text == output_text
```

### Autoquant

If you want to automatically choose a quantization type for quantizable layers (`nn.Linear`) you can use the [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API.
14 changes: 13 additions & 1 deletion src/transformers/quantizers/quantizer_torchao.py
@@ -328,9 +328,21 @@ def create_quantized_param(
        module_fqn, _ = param_name.rsplit(".", 1)
        c = None
        if module_fqn in config.module_fqn_to_config:
            assert not module_fqn.startswith("re:"), (
                "module fqn should not start with `re:`, which is used for specifying a regex"
            )
            c = config.module_fqn_to_config[module_fqn]
        else:
            c = config.module_fqn_to_config.get("_default", None)
            for maybe_module_fqn_pattern in config.module_fqn_to_config:
                if not maybe_module_fqn_pattern.startswith("re:"):
                    continue
                elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
                    # apply the config of the first fully matched pattern
                    c = config.module_fqn_to_config[maybe_module_fqn_pattern]
                    break
            else:
                c = config.module_fqn_to_config.get("_default", None)

        if c is not None:
            # filter_fn: not filtering out any modules
            quantize_(module, c, filter_fn=lambda x, fqn: True)
95 changes: 95 additions & 0 deletions tests/quantization/torchao_integration/test_torchao.py
@@ -46,6 +46,8 @@
    TensorCoreTiledLayout,
)
from torchao.quantization import (
    Float8Tensor,
    Float8WeightOnlyConfig,
    Int8WeightOnlyConfig,
    IntxWeightOnlyConfig,
    MappingType,
@@ -277,6 +279,99 @@ def test_per_module_config_skip(self):
        ]
        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)

    @require_torchao_version_greater_or_equal("0.13.0")
    def test_module_fqn_to_config_regex_basic(self):
        linear_config = Int8WeightOnlyConfig()
        config = ModuleFqnToConfig({"_default": linear_config, r"re:model\.layers\..+\.self_attn\.q_proj": None})
        quant_config = TorchAoConfig(quant_type=config)
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device,
            quantization_config=quant_config,
        )
        # making sure `model.layers.0.self_attn.q_proj` is skipped
        self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor))
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)

        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
        EXPECTED_OUTPUT = [
            "What are we having for dinner?\n\nJessica: (smiling)",
            "What are we having for dinner?\n\nJess: (smiling) I",
        ]
        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)

    @require_torchao_version_greater_or_equal("0.13.0")
    def test_module_fqn_to_config_regex_fullmatch(self):
        """Test that a config is only applied when the module fqn
        fully matches the regex
        """
        linear1_config = Int8WeightOnlyConfig()
        linear2_config = Float8WeightOnlyConfig()
        # intentionally drop the trailing `j` from `q_proj` so the pattern is not a full match
        config = ModuleFqnToConfig(
            {
                r"re:model\.layers\..+\.self_attn\.q_pro": linear1_config,
                "model.layers.3.self_attn.q_proj": linear2_config,
            }
        )
        quant_config = TorchAoConfig(quant_type=config)
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device,
            quantization_config=quant_config,
        )
        # highest precedence is the fully specified module fqn
        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
        # because the regex `model\.layers\..+\.self_attn\.q_pro` doesn't fully match `model.layers.1.self_attn.q_proj`
        # (the trailing `j` is missing), this layer is not expected to be quantized to int8
        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)

        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
        EXPECTED_OUTPUT = [
            "What are we having for dinner?\n\nJessica: (smiling)",
            "What are we having for dinner?\n\nJess: (smiling) I",
        ]
        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)

    @require_torchao_version_greater_or_equal("0.13.0")
    def test_module_fqn_to_config_regex_precedence(self):
        linear1_config = Int8WeightOnlyConfig()
        linear2_config = Float8WeightOnlyConfig()
        config = ModuleFqnToConfig(
            {
                r"re:model\.layers\..+\.self_attn\.q_proj": None,
                "model.layers.3.self_attn.q_proj": linear2_config,
                "_default": linear1_config,
            }
        )
        quant_config = TorchAoConfig(quant_type=config)
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device,
            quantization_config=quant_config,
        )
        # highest precedence is fully specified module fqn
        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
        # second precedence: regex
        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
        # last precedence: _default
        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)

        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
        EXPECTED_OUTPUT = [
            "What are we having for dinner?\n\nJessica: (smiling)",
            "What are we having for dinner?\n\nJess: (smiling) I",
        ]
        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)


@require_torch_accelerator
class TorchAoAcceleratorTest(TorchAoTest):