Add Scheme UX for QuantizationModifier #9

Merged · 3 commits · Jun 27, 2024
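In brief, this PR lets QuantizationModifier be configured with a preset scheme name and a list of target layers, instead of a hand-assembled config_groups dictionary; GPTQModifier forwards the same arguments to the QuantizationModifier it builds internally. A sketch of the before/after UX, assembled from the example diff below (the llmcompressor import path is assumed from this repository's file layout):

# Before: build a QuantizationScheme by hand and register it under config_groups
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationType,
)
from llmcompressor.modifiers.quantization import QuantizationModifier

quant_args = QuantizationArgs(type=QuantizationType.FLOAT)
quant_scheme = QuantizationScheme(
    weights=quant_args, input_activations=quant_args, targets=["Linear"]
)
recipe = QuantizationModifier(
    config_groups={"group_0": quant_scheme}, ignore=["lm_head"]
)

# After: name a preset scheme and the layers it should target
recipe = QuantizationModifier(targets="Linear", scheme="FP8")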
13 changes: 1 addition & 12 deletions examples/quantization/llama7b_fp8_quantization.py
@@ -1,9 +1,4 @@
 import torch
-from compressed_tensors.quantization import (
-    QuantizationArgs,
-    QuantizationScheme,
-    QuantizationType,
-)
 from datasets import load_dataset
 from transformers import AutoTokenizer

@@ -27,13 +22,7 @@ def preprocess(batch):
 ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
 examples = ds.map(preprocess, remove_columns=ds.column_names)

-quant_args = QuantizationArgs(type=QuantizationType.FLOAT)
-quant_scheme = QuantizationScheme(
-    weights=quant_args, input_activations=quant_args, targets=["Linear"]
-)
-recipe = QuantizationModifier(
-    config_groups={"group_0": quant_scheme}, ignore=["lm_head"]
-)
+recipe = QuantizationModifier(targets="Linear", scheme="FP8")

 model = SparseAutoModelForCausalLM.from_pretrained(
     model_stub, torch_dtype=torch.bfloat16, device_map="auto"
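The one-line scheme="FP8" form works because preset names are now resolved through compressed-tensors. A minimal sketch of that resolution, using only the two helpers imported in the diffs below (is_preset_scheme and preset_name_to_scheme, with the same call shapes shown there):

from compressed_tensors.quantization import is_preset_scheme, preset_name_to_scheme

# Mirrors what the modifier does internally: validate the preset name,
# then expand it into a full QuantizationScheme attached to the targets.
if is_preset_scheme("FP8"):
    scheme = preset_name_to_scheme("FP8", ["Linear"])
    config_groups = {"group_0": scheme}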
35 changes: 3 additions & 32 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -2,11 +2,7 @@
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
-from compressed_tensors.quantization import (
-    QuantizationScheme,
-    is_preset_scheme,
-    preset_name_to_scheme,
-)
+from compressed_tensors.quantization import QuantizationScheme
 from pydantic import Field
 from torch.nn import Module

@@ -288,6 +284,8 @@ def _build_quant_modifier(self):

         quantization_args_names = [
             "config_groups",
+            "targets",
+            "scheme",
             "num_calibration_steps",
             "ignore",
             "disable_quantization_observer_epoch",
@@ -299,33 +297,6 @@ def _build_quant_modifier(self):
             if getattr(self, key, False)
         }

-        if isinstance(self.targets, str):
-            self.targets = [self.targets]
-
-        if self.scheme is not None:
-            # takes precedence over config_groups
-
-            if isinstance(self.scheme, str) and is_preset_scheme(self.scheme):
-                # attach targets to scheme
-                self.scheme = {self.scheme: self.targets}
-
-            quant_args["config_groups"] = {}
-            for idx, key in enumerate(self.scheme.keys()):
-                if is_preset_scheme(key):
-                    scheme = preset_name_to_scheme(key, self.scheme[key])
-                else:
-                    scheme = QuantizationScheme.model_validate(
-                        {"targets": self.scheme[key], **self.scheme}
-                    )
-
-                group_name = f"group_{idx}"
-                quant_args["config_groups"][group_name] = scheme
-
-        if "config_groups" not in quant_args or len("config_groups") == 0:
-            default_quant_scheme = QuantizationScheme.default_scheme(
-                targets=self.targets
-            )
-            quant_args["config_groups"] = {"group_0": default_quant_scheme}
         _LOGGER.info(f"Building quantization modifier with args: {quant_args}")
         vllm_quant_config = {"QuantizationModifier": quant_args}
         self._build_quant_modifier_from_dict(vllm_quant_config)
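Since "targets" and "scheme" now appear in quantization_args_names, GPTQModifier simply forwards them to the QuantizationModifier it wraps rather than resolving config_groups itself. A hedged sketch of the resulting GPTQ-side UX (the constructor keywords mirror the forwarded names above, the W8A8 preset name is taken from the docstring below, and the import path is assumed from this repository's layout):

from llmcompressor.modifiers.quantization import GPTQModifier

# The same scheme/targets shorthand applies to GPTQ, which passes these
# arguments through to its internal QuantizationModifier.
recipe = GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"])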
42 changes: 40 additions & 2 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -1,12 +1,14 @@
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from compressed_tensors.quantization import (
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
     freeze_module_quantization,
+    is_preset_scheme,
+    preset_name_to_scheme,
     set_module_for_calibration,
 )
 from pydantic import Field
@@ -33,15 +35,23 @@ class QuantizationModifier(Modifier):
         modules. Modules not matching a scheme target will NOT be quantized.
     :param ignore: optional list of module class names or submodule names to not
         quantize even if they match a target in config_groups. Defaults to empty list.
+    :param scheme: a single quantization scheme to apply to the model. This is a
+        dictionary that supports all keys from QuantizationScheme except targets, which
+        will be set to the targets parameter set at the modifier level. Can also be set
+        to a dictionary of the format `preset_scheme_name: targets`, for example
+        `W8A8: ['Linear']` for 8-bit weights and activations.
+    :param targets: list of layer names to quantize if a scheme is provided
     :param disable_quantization_observer_epoch: Epoch to disable updates to the module
         quantization observers. At this point, quantized weights and zero points will
         not be updated. Leave None to not disable observers during QAT. Default is None
     :param num_calibration_steps: Number of steps to run post training calibration for.
         When None, the entire calibration_dataloader is used
     """

-    config_groups: Dict[str, QuantizationScheme]
+    config_groups: Optional[Dict[str, QuantizationScheme]] = None
     ignore: List[str] = Field(default_factory=list)
+    targets: Union[str, List[str], None] = None
+    scheme: Optional[Union[str, Dict[str, Any]]] = None
     disable_quantization_observer_epoch: Optional[float] = None
     num_calibration_steps: Optional[int] = None

@@ -94,6 +104,34 @@ def on_event(self, state: State, event: Event, **kwargs):
         pass

     def create_init_config(self) -> QuantizationConfig:
+        if self.targets is not None and isinstance(self.targets, str):
+            self.targets = [self.targets]
+
+        if self.scheme is not None:
+            # takes precedence over config_groups
+
+            if isinstance(self.scheme, str) and is_preset_scheme(self.scheme):
+                # attach targets to scheme
+                self.scheme = {self.scheme: self.targets}
+
+            self.config_groups = {}
+            for idx, key in enumerate(self.scheme.keys()):
+                if is_preset_scheme(key):
+                    scheme = preset_name_to_scheme(key, self.scheme[key])
+                else:
+                    scheme = QuantizationScheme.model_validate(
+                        {"targets": self.scheme[key], **self.scheme}
+                    )
+
+                group_name = f"group_{idx}"
+                self.config_groups[group_name] = scheme
+
+        if self.config_groups is None or len(self.config_groups) == 0:
+            default_quant_scheme = QuantizationScheme.default_scheme(
+                targets=self.targets
+            )
+            self.config_groups = {"group_0": default_quant_scheme}
+
         return QuantizationConfig(
             config_groups=self.config_groups,
             quantization_status=QuantizationStatus.INITIALIZED,
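As the new create_init_config shows, the scheme argument is accepted in three shapes. A short sketch of each form and how it resolves, taken from the logic above (import path assumed from this repository's layout):

from llmcompressor.modifiers.quantization import QuantizationModifier

# 1. Preset name as a string: targets supplies the layers, and the pair
#    is expanded via preset_name_to_scheme into config_groups["group_0"].
QuantizationModifier(targets="Linear", scheme="FP8")

# 2. Mapping of preset name -> targets, as described in the docstring.
QuantizationModifier(scheme={"W8A8": ["Linear"]})

# 3. No scheme and no config_groups: falls back to
#    QuantizationScheme.default_scheme(targets=...) under "group_0".
QuantizationModifier(targets="Linear")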
@@ -81,6 +81,7 @@ def test_create_default_quant_modifier(self):
         modifier.on_initialize_structure(testing_harness.get_state())
         assert modifier.quantize
         assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
+        modifier.quantization_modifier_.create_init_config()
         default_config_group_name = "group_0"
         should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[
             default_config_group_name
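The extra create_init_config() call is needed because config_groups now defaults to None and is only materialized (from scheme, targets, or the default scheme) when the config is built, so the test must trigger that resolution before inspecting group_0.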