diff --git a/examples/quantization/llama7b_fp8_quantization.py b/examples/quantization/llama7b_fp8_quantization.py
index 55dfef0cc..80a446123 100644
--- a/examples/quantization/llama7b_fp8_quantization.py
+++ b/examples/quantization/llama7b_fp8_quantization.py
@@ -1,9 +1,4 @@
 import torch
-from compressed_tensors.quantization import (
-    QuantizationArgs,
-    QuantizationScheme,
-    QuantizationType,
-)
 from datasets import load_dataset
 from transformers import AutoTokenizer
 
@@ -27,13 +22,7 @@ def preprocess(batch):
 ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
 examples = ds.map(preprocess, remove_columns=ds.column_names)
 
-quant_args = QuantizationArgs(type=QuantizationType.FLOAT)
-quant_scheme = QuantizationScheme(
-    weights=quant_args, input_activations=quant_args, targets=["Linear"]
-)
-recipe = QuantizationModifier(
-    config_groups={"group_0": quant_scheme}, ignore=["lm_head"]
-)
+recipe = QuantizationModifier(targets="Linear", scheme="FP8")
 
 model = SparseAutoModelForCausalLM.from_pretrained(
     model_stub, torch_dtype=torch.bfloat16, device_map="auto"
diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index 00e3170a5..fe18623b7 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -2,11 +2,7 @@
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.quantization import (
-    QuantizationScheme,
-    is_preset_scheme,
-    preset_name_to_scheme,
-)
+from compressed_tensors.quantization import QuantizationScheme
 from pydantic import Field
 from torch.nn import Module
 
@@ -288,6 +284,8 @@ def _build_quant_modifier(self):
 
         quantization_args_names = [
             "config_groups",
+            "targets",
+            "scheme",
             "num_calibration_steps",
             "ignore",
             "disable_quantization_observer_epoch",
@@ -299,33 +297,6 @@ def _build_quant_modifier(self):
             if getattr(self, key, False)
         }
 
-        if isinstance(self.targets, str):
-            self.targets = [self.targets]
-
-        if self.scheme is not None:
-            # takes precedence over config_groups
-
-            if isinstance(self.scheme, str) and is_preset_scheme(self.scheme):
-                # attach targets to scheme
-                self.scheme = {self.scheme: self.targets}
-
-            quant_args["config_groups"] = {}
-            for idx, key in enumerate(self.scheme.keys()):
-                if is_preset_scheme(key):
-                    scheme = preset_name_to_scheme(key, self.scheme[key])
-                else:
-                    scheme = QuantizationScheme.model_validate(
-                        {"targets": self.scheme[key], **self.scheme}
-                    )
-
-                group_name = f"group_{idx}"
-                quant_args["config_groups"][group_name] = scheme
-
-        if "config_groups" not in quant_args or len("config_groups") == 0:
-            default_quant_scheme = QuantizationScheme.default_scheme(
-                targets=self.targets
-            )
-            quant_args["config_groups"] = {"group_0": default_quant_scheme}
         _LOGGER.info(f"Building quantization modifier with args: {quant_args}")
         vllm_quant_config = {"QuantizationModifier": quant_args}
         self._build_quant_modifier_from_dict(vllm_quant_config)
diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py
index 00073d7d4..e271bbc4e 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/base.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from compressed_tensors.quantization import (
     QuantizationConfig,
@@ -7,6 +7,8 @@
     QuantizationStatus,
     apply_quantization_config,
     freeze_module_quantization,
+    is_preset_scheme,
+    preset_name_to_scheme,
     set_module_for_calibration,
 )
 from pydantic import Field
@@ -33,6 +35,12 @@ class QuantizationModifier(Modifier):
         modules. Modules not matching a scheme target will NOT be quantized.
     :param ignore: optional list of module class names or submodule names to not
         quantize even if they match a target in config_groups. Defaults to empty list.
+    :param scheme: a single quantization scheme to apply to the model. This is a
+        dictionary that supports all keys from QuantizationScheme except targets, which
+        will be set to the targets parameter set at the modifier level. Can also be set
+        to a dictionary of the format `preset_scheme_name: targets` for example:
+        `W8A8: ['Linear']` for weight and activation 8-bit.
+    :param targets: list of layer names to quantize if a scheme is provided
     :param disable_quantization_observer_epoch: Epoch to disable updates to the module
         quantization observers. At this point, quantized weights and zero points will
         not be updated. Leave None to not disable observers during QAT. Default is None
@@ -40,8 +48,10 @@
         When None, the entire calibration_dataloader is used
     """
 
-    config_groups: Dict[str, QuantizationScheme]
+    config_groups: Optional[Dict[str, QuantizationScheme]] = None
     ignore: List[str] = Field(default_factory=list)
+    targets: Union[str, List[str], None] = None
+    scheme: Optional[Union[str, Dict[str, Any]]] = None
     disable_quantization_observer_epoch: Optional[float] = None
     num_calibration_steps: Optional[int] = None
 
@@ -94,6 +104,34 @@ def on_event(self, state: State, event: Event, **kwargs):
         pass
 
     def create_init_config(self) -> QuantizationConfig:
+        if self.targets is not None and isinstance(self.targets, str):
+            self.targets = [self.targets]
+
+        if self.scheme is not None:
+            # takes precedence over config_groups
+
+            if isinstance(self.scheme, str) and is_preset_scheme(self.scheme):
+                # attach targets to scheme
+                self.scheme = {self.scheme: self.targets}
+
+            self.config_groups = {}
+            for idx, key in enumerate(self.scheme.keys()):
+                if is_preset_scheme(key):
+                    scheme = preset_name_to_scheme(key, self.scheme[key])
+                else:
+                    scheme = QuantizationScheme.model_validate(
+                        {"targets": self.scheme[key], **self.scheme}
+                    )
+
+                group_name = f"group_{idx}"
+                self.config_groups[group_name] = scheme
+
+        if self.config_groups is None or len(self.config_groups) == 0:
+            default_quant_scheme = QuantizationScheme.default_scheme(
+                targets=self.targets
+            )
+            self.config_groups = {"group_0": default_quant_scheme}
+
         return QuantizationConfig(
             config_groups=self.config_groups,
             quantization_status=QuantizationStatus.INITIALIZED,
diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
index acc1fcdf1..02c81f11c 100644
--- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
+++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
@@ -81,6 +81,7 @@ def test_create_default_quant_modifier(self):
         modifier.on_initialize_structure(testing_harness.get_state())
         assert modifier.quantize
         assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
+        modifier.quantization_modifier_.create_init_config()
         default_config_group_name = "group_0"
         should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[
             default_config_group_name
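
Usage note (a minimal sketch, not part of the patch): with the `targets` and `scheme` parameters added to QuantizationModifier above, a preset scheme can be selected either by name with explicit targets, or as a `preset_scheme_name: targets` mapping per the updated docstring. The import path and the FP8/W8A8 preset names are assumed from the example script and docstring in this diff.

# Sketch only: assumes the import path used by the llm-compressor examples and
# the FP8/W8A8 preset scheme names referenced in the docstring above.
from llmcompressor.modifiers.quantization import QuantizationModifier

# Preset scheme selected by name; `targets` names the module types to quantize.
recipe = QuantizationModifier(targets="Linear", scheme="FP8")

# Mapping form from the docstring: preset name -> targets, with `ignore` still
# available to skip modules such as the LM head.
recipe = QuantizationModifier(scheme={"W8A8": ["Linear"]}, ignore=["lm_head"])

Either form is expanded into config_groups by create_init_config(), which the GPTQ path now defers to instead of building the groups itself.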