diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md
index 8778f9f3e5ea..dae7123999d6 100644
--- a/docs/source/en/quantization/torchao.md
+++ b/docs/source/en/quantization/torchao.md
@@ -445,7 +445,7 @@ output_text = tokenizer.batch_decode(
 print(output_text)
 ```
 
-#### 2. Quantizing different layers with different quantization configs
+#### 2. Quantizing different layers with different quantization configs (no regex)
 
 ```py
 import torch
@@ -484,6 +484,119 @@ output_text = tokenizer.batch_decode(
 print(output_text)
 ```
 
+#### 3. Quantizing different layers with different quantization configs (with regex)
+
+We can also use a regex to specify a config for all modules whose `module_fqn` matches
+the pattern. All regex keys should start with `re:`, for example `re:layers\..*\.gate_proj` will
+match all modules like `layers.0.gate_proj`. See [here](https://github.com/pytorch/ao/blob/2fe0ca0899c730c528efdbec8886feaa38879f39/torchao/quantization/quant_api.py#L2392) for the docs.
+
+```py
+import logging
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
+
+# Configure logging to see warnings and debug information
+logging.basicConfig(
+    level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s"
+)
+
+# Enable specific loggers that might contain the serialization warnings
+logging.getLogger("transformers").setLevel(logging.INFO)
+logging.getLogger("torchao").setLevel(logging.INFO)
+logging.getLogger("safetensors").setLevel(logging.INFO)
+logging.getLogger("huggingface_hub").setLevel(logging.INFO)
+
+model_id = "facebook/opt-125m"
+
+from torchao.quantization import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Int4WeightOnlyConfig,
+    IntxWeightOnlyConfig,
+    PerRow,
+    PerAxis,
+    ModuleFqnToConfig,
+    Float8Tensor,
+    Int4TilePackedTo4dTensor,
+    IntxUnpackedToInt8Tensor,
+)
+
+float8dyn = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+int4wo = Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d")
+intxwo = IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0))
+
+qconfig_dict = {
+    # highest priority
+    "model.decoder.layers.3.self_attn.q_proj": int4wo,
+    "model.decoder.layers.3.self_attn.k_proj": int4wo,
+    "model.decoder.layers.3.self_attn.v_proj": int4wo,
+    # vllm
+    "model.decoder.layers.3.self_attn.qkv_proj": int4wo,
+
+    "re:model\.decoder\.layers\..+\.self_attn\.q_proj": float8dyn,
+    "re:model\.decoder\.layers\..+\.self_attn\.k_proj": float8dyn,
+    "re:model\.decoder\.layers\..+\.self_attn\.v_proj": float8dyn,
+    # this pattern should not take effect and we'll fall back to `_default`,
+    # since it is not a full match (the trailing `j` is missing)
+    "re:model\.decoder\.layers\..+\.self_attn\.out_pro": float8dyn,
+    # vllm
+    "re:model\.decoder\.layers\..+\.self_attn\.qkv_proj": float8dyn,
+
+    "_default": intxwo,
+}
+quant_config = ModuleFqnToConfig(qconfig_dict)
+quantization_config = TorchAoConfig(quant_type=quant_config)
+quantized_model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+    quantization_config=quantization_config,
+)
+print("quantized model:", quantized_model)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+for i in range(12):
+    if i == 3:
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Int4TilePackedTo4dTensor)
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Int4TilePackedTo4dTensor)
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Int4TilePackedTo4dTensor)
+    else:
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Float8Tensor)
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Float8Tensor)
+        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Float8Tensor)
+    assert isinstance(quantized_model.model.decoder.layers[i].self_attn.out_proj.weight, IntxUnpackedToInt8Tensor)
+
+# Manual Testing
+prompt = "What are we having for dinner?"
+print("Prompt:", prompt)
+inputs = tokenizer(
+    prompt,
+    return_tensors="pt",
+).to("cuda")
+# set temperature to 0 to make the result deterministic
+generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, temperature=0)
+
+correct_output_text = tokenizer.batch_decode(
+    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print("Response:", correct_output_text[0][len(prompt) :])
+
+# save the quantized model so it can be reloaded below
+# (torchao checkpoints currently need safe_serialization=False)
+save_to = "opt-125m-module-fqn-to-config-regex"
+quantized_model.save_pretrained(save_to, safe_serialization=False)
+tokenizer.save_pretrained(save_to)
+
+# Load model from saved checkpoint
+reloaded_model = AutoModelForCausalLM.from_pretrained(
+    save_to,
+    device_map="cuda:0",
+    torch_dtype=torch.bfloat16,
+    # quantization_config=quantization_config,
+)
+
+generated_ids = reloaded_model.generate(**inputs, max_new_tokens=128, temperature=0)
+output_text = tokenizer.batch_decode(
+    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print("Response:", output_text[0][len(prompt) :])
+
+assert correct_output_text == output_text
+```
+
 ### Autoquant
 
 If you want to automatically choose a quantization type for quantizable layers (`nn.Linear`) you can use the [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API.
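The quantizer change below implements the lookup order described in the new docs section: an exact `module_fqn` key wins, otherwise the first `re:` pattern that fully matches is used, otherwise `_default`. As a minimal standalone sketch of that precedence (the `resolve_config` helper and the string placeholder configs are illustrative only, not a transformers or torchao API):

```py
import re

def resolve_config(module_fqn, module_fqn_to_config):
    # 1) an exact fqn key has the highest priority
    if module_fqn in module_fqn_to_config:
        return module_fqn_to_config[module_fqn]
    # 2) otherwise, the first `re:` pattern that *fully* matches the fqn
    for pattern, config in module_fqn_to_config.items():
        if pattern.startswith("re:") and re.fullmatch(pattern[3:], module_fqn):
            return config
    # 3) otherwise, fall back to `_default` (or None if absent)
    return module_fqn_to_config.get("_default", None)

qconfig_dict = {
    "model.decoder.layers.3.self_attn.q_proj": "int4wo",
    r"re:model\.decoder\.layers\..+\.self_attn\.q_proj": "float8dyn",
    "_default": "intxwo",
}
assert resolve_config("model.decoder.layers.3.self_attn.q_proj", qconfig_dict) == "int4wo"
assert resolve_config("model.decoder.layers.1.self_attn.q_proj", qconfig_dict) == "float8dyn"
assert resolve_config("model.decoder.layers.1.self_attn.out_proj", qconfig_dict) == "intxwo"
```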
diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
index 0b9924f667b3..44f0be43da80 100644
--- a/src/transformers/quantizers/quantizer_torchao.py
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -328,9 +328,21 @@ def create_quantized_param(
             module_fqn, _ = param_name.rsplit(".", 1)
             c = None
             if module_fqn in config.module_fqn_to_config:
+                assert not module_fqn.startswith("re:"), (
+                    "module fqn should not start with `re:`, which is reserved for specifying regex patterns"
+                )
                 c = config.module_fqn_to_config[module_fqn]
             else:
-                c = config.module_fqn_to_config.get("_default", None)
+                for maybe_module_fqn_pattern in config.module_fqn_to_config:
+                    if not maybe_module_fqn_pattern.startswith("re:"):
+                        continue
+                    elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
+                        # we'll apply the config for the first fully matched pattern
+                        c = config.module_fqn_to_config[maybe_module_fqn_pattern]
+                        break
+                else:
+                    c = config.module_fqn_to_config.get("_default", None)
+
             if c is not None:
                 # filter_fn: not filtering out any modules
                 quantize_(module, c, filter_fn=lambda x, fqn: True)
diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py
index 1ddc2de0801f..896e999d7666 100644
--- a/tests/quantization/torchao_integration/test_torchao.py
+++ b/tests/quantization/torchao_integration/test_torchao.py
@@ -46,6 +46,8 @@
     TensorCoreTiledLayout,
 )
 from torchao.quantization import (
+    Float8Tensor,
+    Float8WeightOnlyConfig,
     Int8WeightOnlyConfig,
     IntxWeightOnlyConfig,
     MappingType,
@@ -277,6 +279,99 @@ def test_per_module_config_skip(self):
         ]
         self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
 
+    @require_torchao_version_greater_or_equal("0.13.0")
+    def test_module_fqn_to_config_regex_basic(self):
+        linear_config = Int8WeightOnlyConfig()
+        config = ModuleFqnToConfig({"_default": linear_config, r"re:model\.layers\..+\.self_attn\.q_proj": None})
+        quant_config = TorchAoConfig(quant_type=config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            quantization_config=quant_config,
+        )
+        # make sure `model.layers.0.self_attn.q_proj` is skipped
+        self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor))
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = [
+            "What are we having for dinner?\n\nJessica: (smiling)",
+            "What are we having for dinner?\n\nJess: (smiling) I",
+        ]
+        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
+
+    @require_torchao_version_greater_or_equal("0.13.0")
+    def test_module_fqn_to_config_regex_fullmatch(self):
+        """Test that a regex config is only applied to fqns that fully
+        match the regex.
+        """
+        linear1_config = Int8WeightOnlyConfig()
+        linear2_config = Float8WeightOnlyConfig()
+        # intentionally drop the trailing `j` of `q_proj` so the pattern is not a full match
+        config = ModuleFqnToConfig(
+            {
+                r"re:model\.layers\..+\.self_attn\.q_pro": linear1_config,
+                "model.layers.3.self_attn.q_proj": linear2_config,
+            }
+        )
+        quant_config = TorchAoConfig(quant_type=config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            quantization_config=quant_config,
+        )
+        # highest precedence is the fully specified module fqn
+        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
+        # because the regex `model\.layers\..+\.self_attn\.q_pro` doesn't fully match
+        # `model.layers.1.self_attn.q_proj` (the trailing `j` is missing from the pattern),
+        # this layer is not expected to be quantized to int8
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = [
+            "What are we having for dinner?\n\nJessica: (smiling)",
+            "What are we having for dinner?\n\nJess: (smiling) I",
+        ]
+        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
+
+    @require_torchao_version_greater_or_equal("0.13.0")
+    def test_module_fqn_to_config_regex_precedence(self):
+        linear1_config = Int8WeightOnlyConfig()
+        linear2_config = Float8WeightOnlyConfig()
+        config = ModuleFqnToConfig(
+            {
+                r"re:model\.layers\..+\.self_attn\.q_proj": None,
+                "model.layers.3.self_attn.q_proj": linear2_config,
+                "_default": linear1_config,
+            }
+        )
+        quant_config = TorchAoConfig(quant_type=config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            quantization_config=quant_config,
+        )
+        # highest precedence is the fully specified module fqn
+        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
+        # second precedence: regex (q_proj is mapped to None, i.e. skipped)
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        # last precedence: _default
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = [
+            "What are we having for dinner?\n\nJessica: (smiling)",
+            "What are we having for dinner?\n\nJess: (smiling) I",
+        ]
+        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
+
 
 @require_torch_accelerator
 class TorchAoAcceleratorTest(TorchAoTest):
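The `test_module_fqn_to_config_regex_fullmatch` test relies on `re.fullmatch` semantics: a pattern that only matches a prefix of the fqn does not apply. A standalone illustration (not part of the test suite), using the truncated pattern from the test:

```py
import re

pattern = r"model\.layers\..+\.self_attn\.q_pro"  # trailing `j` intentionally missing
fqn = "model.layers.1.self_attn.q_proj"

assert re.fullmatch(pattern, fqn) is None            # not a full match, so no config is applied
assert re.search(pattern, fqn) is not None           # a partial match alone is not enough
assert re.fullmatch(pattern + "j", fqn) is not None  # restoring the `j` makes it a full match
```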