diff --git a/auto_round/__init__.py b/auto_round/__init__.py index a7679a615..268065ba4 100644 --- a/auto_round/__init__.py +++ b/auto_round/__init__.py @@ -15,7 +15,7 @@ # support for old api from auto_round.autoround import AutoRoundLLM, AutoRoundMLLM, AutoRoundAdam, AutoRoundDiffusion -from auto_round.schemes import QuantizationScheme +from auto_round.schemes import QuantizationScheme, AutoScheme from auto_round.utils import LazyImport diff --git a/auto_round/__main__.py b/auto_round/__main__.py index fb5acbd9e..d2ff36912 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -19,7 +19,7 @@ from auto_round.compressors import BaseCompressor from auto_round.eval.eval_cli import EvalArgumentParser, _eval_init, eval, eval_task_by_task -from auto_round.schemes import PRESET_SCHEMES +from auto_round.schemes import PRESET_SCHEMES, AutoScheme from auto_round.utils import ( clear_memory, get_device_and_parallelism, @@ -47,7 +47,15 @@ def __init__(self, *args, **kwargs): # choices=["W4A16", "W2A16", "W3A16", "W8A16", "MXFP4", "MXFP8", "NVFP4", "FPW8A16", "FP8_STATIC"], help="quantization scheme", ) - + self.add_argument("--avg_bits", default=None, type=float, help="for auto scheme, number of avg weight bits") + self.add_argument( + "--options", default=None, type=str, help="for auto scheme, options for auto scheme, e.g. 'W4A16,W8A16'" + ) + self.add_argument( + "--ignore_scale_zp_bits", + action="store_true", + help="for auto scheme whether ignore scale zp bits calculation ", + ) self.add_argument("--bits", default=None, type=int, help="number of weight bits") self.add_argument("--group_size", default=None, type=int, help="group size") self.add_argument("--asym", action="store_true", help="whether to use asym quantization") @@ -110,7 +118,7 @@ def __init__(self, *args, **kwargs): self.add_argument( "--scale_dtype", - default="fp16", + default=None, choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"], help="scale data type to use for quantization", ) @@ -512,6 +520,21 @@ def tune(args): extra_config.mllm_config = mllm_config extra_config.diffusion_config = diffusion_config + layer_config = {} + # from auto_round.auto_schemes.haha import get_mixed_config_layer_config + # layer_config = {} + # best_path = get_mixed_config_layer_config(model_name, target_bits=3) + # for item in best_path: + # layer_config[item[0]] = {} + # layer_config[item[0]]["bits"] = item[1] + + if args.avg_bits is not None: + if args.options is None: + raise ValueError("please set --options for auto scheme") + scheme = AutoScheme( + options=args.options, avg_bits=args.avg_bits, ignore_scale_zp_bits=args.ignore_scale_zp_bits + ) + autoround: BaseCompressor = AutoRound( model=model_name, scheme=scheme, @@ -528,6 +551,7 @@ def tune(args): not_use_best_mse=args.not_use_best_mse, enable_adam=args.adam, extra_config=extra_config, + layer_config=layer_config, ) model_name = args.model.rstrip("/") diff --git a/auto_round/auto_schemes/__init__.py b/auto_round/auto_schemes/__init__.py new file mode 100644 index 000000000..e0e5ccb66 --- /dev/null +++ b/auto_round/auto_schemes/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
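# [Editor's sketch -- not part of this patch] The new --avg_bits, --options and
# --ignore_scale_zp_bits flags feed the AutoScheme object that tune() constructs above,
# roughly equivalent to the Python API below. The model name, bit budget and output
# directory are hypothetical; only these three flags are introduced here, the remaining
# CLI arguments come from the existing parser, e.g.:
#     python -m auto_round --model facebook/opt-125m --avg_bits 3.5 --options "W2A16,W4A16"

from auto_round import AutoRound, AutoScheme

scheme = AutoScheme(options="W2A16,W4A16", avg_bits=3.5)
ar = AutoRound(model="facebook/opt-125m", scheme=scheme)
ar.quantize_and_save("./tmp_autoround")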
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+AUTO_SCHEMES_METHODS = {}
+
+
+def register_scheme_methods(names):
+    """Decorator to register a mixed-precision scheme-generation method to the registry.
+
+    The decorated callable implements one mixed-precision search algorithm.
+
+    Args:
+        names: A string, or a list/tuple of strings, naming the method(s) to register.
+
+    Returns:
+        register: The inner decorator that stores the algorithm under each name and returns it unchanged.
+    """
+
+    def register(alg):
+        if isinstance(names, (tuple, list)):
+            for name in names:
+                AUTO_SCHEMES_METHODS[name] = alg
+        else:
+            AUTO_SCHEMES_METHODS[names] = alg
+
+        return alg
+
+    return register
+
+
+import auto_round.auto_schemes.haha  # pylint: disable=E0611,E0401
diff --git a/auto_round/auto_schemes/gen_auto_scheme.py b/auto_round/auto_schemes/gen_auto_scheme.py
new file mode 100644
index 000000000..12e956eba
--- /dev/null
+++ b/auto_round/auto_schemes/gen_auto_scheme.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
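# [Editor's sketch -- not part of this patch] A custom search method can be plugged
# into the AUTO_SCHEMES_METHODS registry defined above and is then looked up by
# GenScheme.get_layer_config() below. The "naive" name and the body are hypothetical;
# the signature mirrors the get_layer_config() call, and a real implementation should
# follow the return contract of the built-in "default" method.

from auto_round.auto_schemes import register_scheme_methods


@register_scheme_methods("naive")
def naive_auto_scheme(
    auto_scheme, model, quant_layer_names, fixed_layer_scheme, dataset, tokenizer, device_map=None
):
    # Keep user-fixed layers untouched and assign the first candidate option to the rest.
    default_option = auto_scheme.options[0]
    return {name: fixed_layer_scheme.get(name, default_option) for name in quant_layer_names}

# Selecting it would then be a matter of AutoScheme(options="W2A16,W4A16", avg_bits=3.0, method="naive").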
+from dataclasses import asdict
+from typing import Iterable, Union
+
+import torch
+
+from auto_round import AutoScheme
+from auto_round.auto_schemes import AUTO_SCHEMES_METHODS
+from auto_round.auto_schemes.utils import compute_avg_bits_for_scheme
+from auto_round.logger import logger
+
+
+class GenScheme:
+    """Generate and validate quantization schemes for model layers."""
+
+    def __init__(
+        self,
+        auto_scheme: AutoScheme,  # TODO support shared layer
+        model: torch.nn.Module,
+        quant_layer_names: Iterable[str],
+        fixed_layer_scheme: dict[str, dict],
+        dataset: str = "pile-10k",  # TODO use auto-round dataset
+        device_map: Union[str, torch.device, int, dict, None] = None,
+        tokenizer=None,
+    ):
+        self.auto_scheme = auto_scheme
+        self.model = model
+        self.tokenizer = tokenizer
+        self.quant_layer_names = quant_layer_names
+        self.fixed_layer_scheme = fixed_layer_scheme
+        self.dataset = dataset
+        self.device_map = device_map
+        self._check_configs()
+
+    def _check_configs(self) -> None:
+        """Validate auto_scheme configuration and ensure avg_bits target is valid."""
+        if isinstance(self.model, torch.nn.Module) and self.tokenizer is None:
+            raise ValueError("tokenizer must not be None if model is nn.Module")
+
+        if not isinstance(self.dataset, str):
+            raise TypeError(f"`dataset` must be a string, got {type(self.dataset).__name__}.")
+
+        min_avg_bit, max_avg_bit = self.compute_avg_bit_range()
+        target = self.auto_scheme.avg_bits
+
+        logger.info("Average bits range: [%.3f, %.3f], target = %.3f", min_avg_bit, max_avg_bit, target)
+        if abs(target - min_avg_bit) < 1e-3 or abs(target - max_avg_bit) < 1e-3:
+            if abs(target - min_avg_bit) < 1e-3:
+                target = min_avg_bit
+            else:
+                target = max_avg_bit
+            self.auto_scheme.avg_bits = target
+
+        if not (min_avg_bit <= target <= max_avg_bit):
+            raise ValueError(
+                f"Target avg_bits={target:.3f} is outside the valid range " f"[{min_avg_bit:.3f}, {max_avg_bit:.3f}]."
+            )
+
+    def get_layer_config(self):
+        method_name = self.auto_scheme.method
+        method_func = AUTO_SCHEMES_METHODS[method_name]
+        layer_config = method_func(
+            self.auto_scheme,
+            self.model,
+            self.quant_layer_names,
+            self.fixed_layer_scheme,
+            self.dataset,
+            self.tokenizer,
+            device_map=self.device_map,
+        )
+        return layer_config
+
+    def compute_avg_bit_range(self) -> tuple[float, float]:
+        """Compute the min and max average bitwidths among candidate quantization options."""
+        avg_bits = [
+            compute_avg_bits_for_scheme(
+                self.model,
+                self.quant_layer_names,
+                self.fixed_layer_scheme,
+                option,
+                self.auto_scheme.ignore_scale_zp_bits,
+            )[0]
+            for option in self.auto_scheme.options
+        ]
+        return min(avg_bits), max(avg_bits)
diff --git a/auto_round/auto_schemes/utils.py b/auto_round/auto_schemes/utils.py
new file mode 100644
index 000000000..73191c40f
--- /dev/null
+++ b/auto_round/auto_schemes/utils.py
@@ -0,0 +1,437 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
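# [Editor's note -- worked example, not part of this patch] The helpers below account for
# quantization overhead roughly as follows. For a hypothetical 4096x4096 Linear under
# W4A16 (bits=4, group_size=128, no super-group), compute_layer_bits() gives:
#     weight bits : 4 * 4096 * 4096                        = 67,108,864
#     groups      : ~4096 * 4096 / 128                     = 131,072
#     scale + zp  : 131,072 * (16 + 4)                     = 2,621,440
#     avg bits    : (67,108,864 + 2,621,440) / 16,777,216  ~= 4.16
# With ignore_scale_zp_bits=True the same layer counts as exactly 4.0 bits, and an
# unquantized (>= 16-bit) layer counts as 16 bits (or 32 when super_bits is set, mirroring
# the GGUF fp32 fallback). compute_avg_bits_for_scheme() aggregates this over all
# quantizable layers to produce the bounds that GenScheme validates the target against.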
+import re +from dataclasses import asdict, fields +from typing import Iterable, Union + +import torch +from accelerate import dispatch_model, infer_auto_device_map +from accelerate.utils import get_balanced_memory + +from auto_round.low_cpu_mem import get_module +from auto_round.schemes import QuantizationScheme, preset_name_to_scheme +from auto_round.utils import ( + SUPPORTED_LAYER_TYPES, + check_to_quantized, + get_block_names, + get_layer_features, + is_hpex_available, +) + + +def apply_quant_scheme( + model: torch.nn.Module, + quant_layer_names: Iterable[str], + fixed_layer_scheme: dict[str, dict], + scheme: Union[str, dict], # TODO add scale_dtype +) -> None: + """Apply a quantization scheme to each quantized layer. + + Args: + model: The model whose layers are to be updated. + scheme: The scheme preset name or dictionary to apply. + quant_layer_names: Iterable of layer names to quantize. + fixed_layer_scheme: Dictionary of fixed per-layer quantization schemes. + """ + for name in quant_layer_names: + layer_scheme = fixed_layer_scheme.get(name, scheme) + if isinstance(layer_scheme, str): + layer_scheme = asdict(preset_name_to_scheme(layer_scheme)) + + module = get_module(model, name) + for key, value in layer_scheme.items(): + setattr(module, key, value) + + +def remove_quant_scheme( + model: torch.nn.Module, +) -> None: + """Remove attributes corresponding to the applied quantization scheme. + + Args: + model: The model whose layers are to be cleared. + """ + scheme_keys = [f.name for f in fields(QuantizationScheme)] + ["scale_dtype"] + for n, m in model.named_modules(): + for key in scheme_keys: + if hasattr(m, key): + delattr(m, key) + + +def compute_avg_bits_for_scheme( + model: torch.nn.Module, + quant_layer_names: Iterable[str], + fixed_layer_scheme: dict[str, dict], + scheme: Union[str, dict, None] = None, + ignore_scale_zp_bits: bool = False, +) -> tuple[float, float]: + """Compute the average and total bit usage for the given quantization scheme. + + Args: + model: The model to analyze. + quant_layer_names: Iterable of layer names to include. + fixed_layer_scheme: Dictionary of fixed per-layer quantization schemes. + scheme: Optional scheme to temporarily apply before measuring. + ignore_scale_zp_bits: If True, ignores overhead from scale and zero-points. + + Returns: + A tuple (avg_bits, total_quantized_bits): + avg_bits: Average bitwidth per parameter. + total_quantized_bits: Total quantized bit count. + """ + if scheme is not None: + apply_quant_scheme(model, quant_layer_names, fixed_layer_scheme, scheme) + + total_params = 0 + total_quantized_bits = 0 + + for name in quant_layer_names: + module = get_module(model, name) + # if isinstance(module,torch.nn.Embedding): + # continue + if not hasattr(module, "weight"): + continue + total_params += module.weight.numel() + layer_bits, _ = compute_layer_bits(module, ignore_scale_zp_bits) + total_quantized_bits += layer_bits + + avg_bits = float(total_quantized_bits) / total_params + + if scheme is not None: + remove_quant_scheme(model) + + return avg_bits, total_quantized_bits + + +def compute_avg_bits_for_model(model: torch.nn.Module, ignore_scale_zp_bits: bool = False): + """Compute the average and total bit usage for the entire model. + + Args: + model: The model to analyze. + ignore_scale_zp_bits: If True, ignores overhead from scale and zero-points. 
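# [Editor's sketch -- not part of this patch] compute_avg_bits_for_scheme() above works by
# temporarily stamping the candidate scheme onto each layer as attributes, measuring, and
# then stripping those attributes again. A minimal manual equivalent, assuming `model` is an
# already-loaded nn.Module and using only helpers defined in this file:
#
#     names = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
#     apply_quant_scheme(model, names, {}, "W4A16")   # annotate layers with the scheme
#     avg_bits, total_bits = compute_avg_bits_for_model(model)
#     remove_quant_scheme(model)                      # drop the temporary attributes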
+ if scheme is not None: + apply_quant_scheme(model, quant_layer_names, fixed_layer_scheme, scheme) + """ + + total_params = 0 + total_quantized_bits = 0 + + for n, module in model.named_modules(): + if not hasattr(module, "bits"): + continue + if not hasattr(module, "weight"): + continue + # if isinstance(module,torch.nn.Embedding): # Tricky setting for Embedding + # continue + total_params += module.weight.numel() + layer_bits, _ = compute_layer_bits(module, ignore_scale_zp_bits) + total_quantized_bits += layer_bits + + avg_bits = float(total_quantized_bits) / total_params + + return avg_bits, total_quantized_bits + + +def compute_layer_bits( + layer: torch.nn.Module, + ignore_scale_zp_bits: bool = False, +) -> tuple[int, float]: + """Compute total and average bitwidth for a single quantized layer. + + Args: + layer: A PyTorch layer with quantization attributes. + ignore_scale_zp_bits: Whether to ignore scale/zero-point overhead. + + Returns: + A tuple (total_bits, avg_bits) representing bit usage. + """ + weight = layer.weight + n_param = weight.numel() + weight_bits = getattr(layer, "bits", 16) + group_size = getattr(layer, "group_size", 128) + super_group_size = getattr(layer, "super_group_size", None) + super_weight_bits = getattr(layer, "super_bits", None) + + # Unquantized layer or ignoring scale/zp overhead + if weight_bits >= 16 or ignore_scale_zp_bits: + if super_weight_bits is not None: # reset gguf 16 bits to 32 bits, TODO gguf q4_0, q4_1 have bug (wenhua) + if weight_bits >= 16: + return 32 * n_param, 32 + + return weight_bits * n_param, min(16, weight_bits) + + in_features, out_features = get_layer_features(layer) + + # Determine number of groups based on group size + if group_size > 0: + n_group = out_features * (in_features + group_size - 1) // group_size + elif group_size == 0: + n_group = 1 + elif group_size == -1: + n_group = out_features + else: + raise ValueError(f"Invalid group_size {group_size}") + + # Compute auxiliary bits (scales, zero-points, or double quantization) + aux_total_bits = 0 + if not super_group_size: + scale_bits = 16 + zp_bits = weight_bits + aux_total_bits = n_group * (scale_bits + zp_bits) + else: + aux_total_bits += n_group * super_weight_bits * 2 + n_super_group = (n_group + super_group_size - 1) // super_group_size + aux_total_bits += n_super_group * 32 * 2 # 32-bit scale and min_v + + total_bits = weight_bits * n_param + aux_total_bits + avg_bits = total_bits / n_param + return total_bits, avg_bits + + +def parse_all_available_device(device_map: Union[str, torch.device, int, dict, None] = None) -> list: + """ + Parse the device map and return a list of all available devices. + + Supported input formats: + - None: Automatically detect all available devices + - int: A single device index (e.g., 0) + - str: Examples: + "cpu" + "cuda:0,cuda:1" + "0,1" (numeric device indices) + - dict: Extract all device values from the dictionary + - torch.device: e.g. torch.device("cuda:0") + + Returns: + list[str]: Normalized device names, e.g., ["cuda:0", "cuda:1"] or ["cpu"] + """ + + # === Step 1. Detect available device types === + device_types = [] + if torch.cuda.is_available(): + device_types.append("cuda") + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device_types.append("xpu") + if hasattr(torch, "hpu") and is_hpex_available(): + device_types.append("hpu") + + # Always include CPU as a fallback + if not device_types: + device_types = ["cpu"] + + # === Step 2. 
Parse different input formats === + if device_map is None: + # Automatically detect one available device + if "cuda" in device_types: + return ["cuda:0"] + elif "xpu" in device_types: + return ["xpu:0"] + elif "hpu" in device_types: + return ["hpu:0"] + else: + return ["cpu"] + + if isinstance(device_map, torch.device): + # Handle torch.device objects + dev_type = device_map.type + index = device_map.index + if dev_type == "cpu": + return ["cpu"] + if index is None: + index = 0 + return [f"{dev_type}:{index}"] + + if isinstance(device_map, int): + # Integer input → use primary available device type + device_type = device_types[0] + return [f"{device_type}:{device_map}"] if device_type != "cpu" else ["cpu"] + + if isinstance(device_map, str): + # Remove whitespace + device_map = device_map.strip() + if device_map.lower() == "cpu": + return ["cpu"] + + # Split by commas + parts = [x.strip() for x in device_map.split(",") if x.strip()] + parsed = [] + for p in parts: + if p.isdigit(): + # Numeric → assign to first available device type + device_type = device_types[0] + parsed.append(f"{device_type}:{p}" if device_type != "cpu" else "cpu") + else: + parsed.append(p) + return parsed + + if isinstance(device_map, dict): + # Extract all devices recursively from dict values + devices = set() + for v in device_map.values(): + devices.update(parse_all_available_device(v)) + return sorted(devices) + + raise TypeError(f"Unsupported device_map type: {type(device_map)}") + + +# Important Notice This dispatch does not follow dict device_map, just extract all available devices and use them +def dispatch_model_by_all_available_devices( + model: torch.nn.Module, device_map: Union[str, int, dict, None] +) -> torch.nn.Module: + if device_map is None: + device_map = 0 + + no_split_modules = getattr(model, "_no_split_modules", []) + if device_map == "auto": + max_memory = get_balanced_memory( + model, + max_memory=None, + no_split_module_classes=no_split_modules, + ) + device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=no_split_modules) + model = dispatch_model(model, device_map=device_map) + return model + + devices = parse_all_available_device(device_map) + + if len(devices) == 1: + model.to(devices[0]) + return model + + max_memory = get_balanced_memory( + model, + max_memory=None, + no_split_module_classes=no_split_modules, + ) + + # Filter max_memory with devices + # assume only one GPU model + new_max_memory = {} + for device in devices: + if ":" in device: + device = int(device.split(":")[-1]) + elif device == "cpu": + device = "cpu" + else: + raise ValueError(f"Unsupported device {device} in device_map: {device_map}") + new_max_memory[device] = max_memory[device] + + device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=no_split_modules) + model = dispatch_model(model, device_map=device_map) + return model + + +def merge_lists_unionfind(list_of_lists): + parent = {} + + def find(x): + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + def union(x, y): + root_x, root_y = find(x), find(y) + if root_x != root_y: + parent[root_y] = root_x + + # 初始化并查集 + for lst in list_of_lists: + for item in lst: + if item not in parent: + parent[item] = item + for i in range(1, len(lst)): + union(lst[0], lst[i]) + + # 收集结果 + groups = {} + for item in parent: + root = find(item) + groups.setdefault(root, []).append(item) + return list(groups.values()) + + +def parse_shared_layers(model: torch.nn.Module, 
shared_patterns: Iterable[Iterable[str]]) -> list[list[str]]: + """ + Parse shared layer groups based on regex or substring matches. + + Args: + model (torch.nn.Module): The model whose modules will be analyzed. + shared_patterns (Iterable[Iterable[str]]): + Each inner iterable defines one shared group. Each element can be: + - a string: checked by full-name or substring match + - a regex pattern: checked by re.fullmatch or re.search + + Returns: + list[list[str]]: A list of matched shared layer groups. + """ + if not shared_patterns: + return [] + # Retrieve all high-level block names (for example, transformer blocks) + for n, m in model.named_modules(): + m.tmp_name = n # attach global name + + block_names = get_block_names(model, quant_vision=True) + block_names = [item for sublist in block_names for item in sublist] + + # Collect all supported layer names from the model + supported_layer_names = [name for name, module in model.named_modules() if type(module) in SUPPORTED_LAYER_TYPES] + + # Separate groups into those already fully matched and those requiring pattern matching + direct_match_groups = [] + fuzzy_match_groups = [] + for group in shared_patterns: + match_status = {name: (name in supported_layer_names) for name in group} + if all(match_status.values()): + direct_match_groups.append(list(match_status.keys())) + else: + fuzzy_match_groups.append(match_status) + + matched_groups = list(direct_match_groups) + + # Search each block for modules matching remaining patterns + for block_name in block_names: + block_module = get_module(model, block_name) + block_layer_local_names = [ + name for name, module in block_module.named_modules() if type(module) in SUPPORTED_LAYER_TYPES + ] + block_layer_names = [] + for name in block_layer_local_names: + module = get_module(block_module, name) + block_layer_names.append(module.tmp_name) + + for group in fuzzy_match_groups: + matched_layers = set() + for pattern, is_direct in group.items(): + if is_direct: + matched_layers.add(pattern) + continue + + for layer_name in block_layer_names: + # Try regex match first + try: + if re.fullmatch(pattern, layer_name) or re.search(pattern, layer_name): + matched_layers.add(layer_name) + continue + except re.error: + pass # Not a valid regex, fallback to substring matching + + # Substring or partial match + if pattern in layer_name: + matched_layers.add(layer_name) + + if matched_layers: + matched_groups.append(sorted(matched_layers)) + matched_groups = merge_lists_unionfind(matched_groups) + return matched_groups diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 68420d65c..a335b343c 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -26,7 +26,7 @@ MLLMCompressor, ) from auto_round.logger import deprecated, logger -from auto_round.schemes import QuantizationScheme +from auto_round.schemes import AutoScheme, QuantizationScheme from auto_round.utils import is_diffusion_model, is_mllm_model @@ -64,7 +64,7 @@ def __new__( cls, model: Union[torch.nn.Module, str], tokenizer=None, - scheme: Union[str, dict, QuantizationScheme] = "W4A16", + scheme: Union[str, dict, QuantizationScheme, AutoScheme] = "W4A16", layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", iters: int = 200, diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 8ceaefc00..d5110348c 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -26,6 +26,7 
@@ import accelerate import torch from accelerate.big_modeling import dispatch_model, infer_auto_device_map +from accelerate.utils import get_balanced_memory from torch import autocast from tqdm import tqdm from transformers import set_seed @@ -36,7 +37,7 @@ from auto_round.export.export_to_gguf.config import GGUF_CONFIG, GGUF_INNER_CONFIG, ModelType from auto_round.logger import logger from auto_round.low_cpu_mem.utils import get_layers_before_block -from auto_round.schemes import QuantizationScheme, preset_name_to_scheme +from auto_round.schemes import AutoScheme, QuantizationScheme, get_gguf_scheme, preset_name_to_scheme from auto_round.sign_sgd import SignSGD from auto_round.special_model_handler import _handle_moe_model from auto_round.utils import ( @@ -76,7 +77,6 @@ get_lm_head_name, get_max_vram, get_module, - get_quant_keys, get_shared_keys, htcore, infer_bits_by_data_type, @@ -92,6 +92,7 @@ mv_module_from_gpu, reset_params, set_amax_for_all_moe_layers, + set_layer_config, set_module, to_device, to_dtype, @@ -130,7 +131,7 @@ def __init__( self, model: Union[torch.nn.Module, str], tokenizer=None, - scheme: Union[str, dict, QuantizationScheme] = "W4A16", + scheme: Union[str, dict, QuantizationScheme, AutoScheme] = "W4A16", layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", iters: int = 200, @@ -203,9 +204,40 @@ def __init__( ... # ... ... } """ - self.scheme = None - self._parse_and_set_scheme(scheme, kwargs) + if isinstance(scheme, AutoScheme): + if len(scheme.options) <= 0: + raise ValueError("options of AutoScheme must not be empty") + options = [] + for option in scheme.options: + new_option = self._parse_and_set_scheme(option, kwargs) + options.append(new_option) + scheme.options = options + for opt in options: + if isinstance(opt, str) and opt == "BF16": + continue + if isinstance(opt, QuantizationScheme): + if opt.bits >= 16 and (opt.act_bits is None or opt.act_bits >= 16): + continue + self.scheme = opt # Choose the first one that not 16 bits + break + + # apply scheme to set default bits + self._parse_and_set_scheme(self.scheme, kwargs) + + self.is_auto_scheme = True + + else: + self.scheme = self._parse_and_set_scheme(scheme, kwargs) + self.is_auto_scheme = False + + scheme_keys = [f.name for f in fields(QuantizationScheme)] + for key in scheme_keys: + kwargs.pop(key, None) + + gguf_scheme_name = get_gguf_scheme(self.scheme) + # GGUF uses fp32 scale dtype as default + scale_dtype = kwargs.pop("scale_dtype", "fp32") if gguf_scheme_name else kwargs.pop("scale_dtype", "fp16") # Extra/legacy kwargs for backward compatibility # Major version releases may pack them with extra configuration options amp = kwargs.pop("amp", True) @@ -218,7 +250,6 @@ def __init__( sampler = kwargs.pop("sampler", "rand") not_use_best_mse = kwargs.pop("not_use_best_mse", False) dynamic_max_gap = kwargs.pop("dynamic_max_gap", -1) - scale_dtype = kwargs.pop("scale_dtype", "fp16") nblocks = kwargs.pop("nblocks", 1) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) to_quant_block_names: Union[str, list, None] = kwargs.pop("to_quant_block_names", None) @@ -233,13 +264,17 @@ def __init__( self.diffusion = kwargs.pop("diffusion") if "diffusion" in kwargs else False # Scale factor for RAM usage per parameter. 
self.mem_per_param_scale = kwargs.pop("mem_per_param_scale", None) - fp_layers = kwargs.pop("fp_layers", None) + self.fp_layers = kwargs.pop("fp_layers", "") + self.layer_config = layer_config + self.supported_types = SUPPORTED_LAYER_TYPES + self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES + self.scale_dtype = convert_dtype_str2torch(scale_dtype) if kwargs: logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. Please check them.") if "CUBLAS_WORKSPACE_CONFIG" not in os.environ: os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - # deprecated, default not to use torch.use_deterministic_algorithms + # Deprecated, default not to use torch.use_deterministic_algorithms if not disable_deterministic_algorithms or enable_deterministic_algorithms: if not disable_deterministic_algorithms: logger.warning( @@ -256,8 +291,8 @@ def __init__( if isinstance(model, str): model, tokenizer, low_cpu_mem_usage = llm_load_model( model, - device="cpu", - low_cpu_mem_mode=low_cpu_mem_usage, # always load cpu first + device="cpu", # always load cpu first + low_cpu_mem_mode=low_cpu_mem_usage, ) elif tokenizer is None and not self.diffusion and iters > 0: raise ValueError("A tokenizer must be set for non-str model input") @@ -283,6 +318,77 @@ def __init__( if device_map is None: device_map = 0 + if isinstance(scheme, AutoScheme): + if self.mllm: + logger.info("AutoScheme is not yet supported for multimodal LLMs.") + sys.exit(-1) + + if getattr(model, "is_fp8", False): + logger.info("AutoScheme does not currently support FP8 models.") + sys.exit(-1) + + all_dtypes = [] + for option in scheme.options: + # Skip pure BF16 option + if option == "BF16": + continue + + # Resolve the quantization scheme or data type + dtype = "int" + if isinstance(option, str): + option = preset_name_to_scheme(option) + + if isinstance(option, QuantizationScheme): + dtype = option.data_type + elif isinstance(option, dict): + dtype = option.get("data_type", "int") + + all_dtypes.append(dtype) + + # Check for mixed data types + unique_dtypes = set(all_dtypes) + if len(unique_dtypes) > 1: + logger.warning( + "Models with mixed data_types " + "cannot yet be exported to real formats except GGUF. " + "Please save the model using the `fake` format for now." 
+ ) + + layer_config, self.has_qlayer_outside_block = set_layer_config( + self.model, + self.layer_config, + self.scheme, + self.scale_dtype, + self.supported_types, + self.inner_supported_types, + self.quant_block_list, + self.fp_layers, + self.quant_lm_head, + enable_gguf_official_mixed=False, + is_mllm=self.mllm, + ) + quant_layer_names = layer_config.keys() + scheme_keys = {f.name for f in fields(QuantizationScheme)} + fixed_layer_scheme_new = { + k: {key: v[key] for key in scheme_keys & v.keys()} + for k, v in layer_config.items() + if v.get("fixed_by_user", False) + } + + # mainly using quant_layers and fixed by users + from auto_round.auto_schemes.gen_auto_scheme import GenScheme + + gen_scheme = GenScheme( + scheme, + self.model, + quant_layer_names, + fixed_layer_scheme_new, + dataset, + device_map=device_map, + tokenizer=self.tokenizer, + ) + self.layer_config = gen_scheme.get_layer_config() + # Set device, must place after model loading self._set_device(device_map) @@ -296,15 +402,6 @@ def __init__( self.device_map = None self._set_device_map_in_blocks(self.device_map) - not_quantize_layer_names = get_fp_layer_names(self.model, fp_layers) - if len(not_quantize_layer_names) > 0: - logger.info(f"{not_quantize_layer_names} will not be quantized.") - if layer_config is None: - layer_config = {} - for name in not_quantize_layer_names: - layer_config[name] = {"bits": 16, "act_bits": 16, "data_type": "float", "act_data_type": "float"} - self._parse_layer_config(layer_config) # must place after model init - # Tuning hyperparameters self.seed = seed set_seed(self.seed) @@ -341,7 +438,6 @@ def __init__( if self.static_kv_dtype is not None: logger.warning("The static kv is experimental and currently has limited support.") - self.scale_dtype = convert_dtype_str2torch(scale_dtype) self._set_amp_dtype() self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device if self.act_bits <= 8 and self.amp_dtype == torch.float16: @@ -353,8 +449,6 @@ def __init__( logger.info(f"using {self.model.dtype} for quantization tuning") # Some helpers - self.supported_types = SUPPORTED_LAYER_TYPES - self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES if "hpu" in str(self.device): self.inner_supported_types = tuple(x for x in INNER_SUPPORTED_LAYER_TYPES if x != "FP8Linear") self.batch_dim = None @@ -369,7 +463,7 @@ def __init__( import habana_frameworks.torch.core as htcore # pylint: disable=E0401 import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401] - def _set_device(self, device_map): + def _set_device(self, device_map: Union[str, torch.device, int, dict]) -> None: if hasattr(self, "device") and self.device is not None: return if isinstance(device_map, (str, torch.device, int)): @@ -393,65 +487,16 @@ def _set_device(self, device_map): else: raise TypeError(f"device_map should be [str, torch.device, int, dict], but got {type(device_map)}") - def _parse_layer_config(self, layer_config: dict[str, Union[str, dict, QuantizationScheme]]) -> None: - """Parse and set the layer-wise quantization configuration.""" - # Some other quantization configs - self.layer_config = copy.deepcopy(layer_config) if layer_config is not None else {} - scheme_keys = {f.name for f in fields(QuantizationScheme)} - - for key, item in self.layer_config.items(): - if isinstance(item, str): - config = asdict(preset_name_to_scheme(item.upper())) - elif isinstance(item, QuantizationScheme): - config = asdict(item) - elif isinstance(item, dict): - invalid_keys = set(item) - scheme_keys - if invalid_keys: 
- raise ValueError( - f"Invalid keys {invalid_keys} in layer_config for layer '{key}', " - f"only {scheme_keys} are supported" - ) - config = dict(item) - - # Drop None values - config = {k: v for k, v in config.items() if v is not None} - self.layer_config[key] = config - - if not self.quant_lm_head or (isinstance(self.scheme, str) and self.scheme.lower().startswith("gguf")): - return - for n, _ in self.model.named_modules(): - lm_head_layer_name = n - - if ( - hasattr(self.model, "config") - and self.model.config.tie_word_embeddings - and hasattr(self.model, "_tied_weights_keys") - ): - tied_keys = self.model._tied_weights_keys - for item in tied_keys: - if lm_head_layer_name in item: # TODO extend to encoder-decoder layer, seq classification model - self.quant_lm_head = False - logger.warning( - "reset `quant_lm_head` to `False` as quantizing lm_head with tied weights has not been " - "supported currently" - ) - break - - lm_head_layer_config = self.layer_config[lm_head_layer_name] if lm_head_layer_name in self.layer_config else {} - - for key in scheme_keys: - if key not in lm_head_layer_config: - lm_head_layer_config[key] = getattr(self, key) - - def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kwargs) -> None: + def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kwargs) -> QuantizationScheme: """Parse and set the quantization scheme.""" + res = "" if isinstance(scheme, QuantizationScheme): scheme = asdict(scheme) elif isinstance(scheme, dict): scheme = scheme elif isinstance(scheme, str): + res = scheme # gguf:q4_k_s and gguf_q4_k_m has the same dict scheme, but the result is different scheme = scheme.upper() - self.scheme = scheme scheme = asdict(preset_name_to_scheme(scheme)) scheme_keys = [f.name for f in fields(QuantizationScheme)] for key in scheme_keys: @@ -459,7 +504,7 @@ def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kw setattr(self, key, kwargs[key]) else: setattr(self, key, scheme.get(key, None)) - kwargs.pop(key, None) + # kwargs.pop(key, None) if self.act_dynamic is None: self.act_dynamic = True @@ -492,11 +537,17 @@ def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kw f" match the specified 'act_bits' setting. Resetting 'act_bits' to {tmp_act_bits}." ) if tmp_act_bits is not None and tmp_act_bits < 16: - for supported_dtype in SUPPORTED_DTYPES: # to easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}} + for supported_dtype in SUPPORTED_DTYPES: # To easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}} if self.act_data_type.startswith(supported_dtype): - if supported_dtype + str(tmp_act_bits) == self.act_data_type: # could not replace FP8_e4m3 + if supported_dtype + str(tmp_act_bits) == self.act_data_type: # Could not replace FP8_e4m3 self.act_data_type = supported_dtype break + for key in scheme_keys: + scheme[key] = getattr(self, key) + if res and QuantizationScheme.from_dict(scheme) == preset_name_to_scheme(res): + return res + else: + return QuantizationScheme.from_dict(scheme) def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: """Sets the torch compile configuration for the tuning.""" @@ -740,20 +791,20 @@ def _check_compatibility(self) -> None: " We are likely to release new algorithm for certain configurations in the future." 
) - # Check group_size 32 for auto_round - if ( - self.data_type == "int" - and hasattr(self, "formats") - and any(key in fmt for fmt in self.formats for key in ("auto_round", "auto_gptq", "auto_awq")) - ): - for n, m in self.model.named_modules(): - if type(m) in self.supported_types: - if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0: - self.layer_config[n] = {"bits": 16} - logger.info( - f"{n} will not be quantized due to its shape not being divisible by 32," - " resulting in an exporting issue to autogptq" - ) + # # Check group_size 32 for auto_round + # if ( + # self.data_type == "int" + # and hasattr(self, "formats") + # and any(key in fmt for fmt in self.formats for key in ("auto_round", "auto_gptq", "auto_awq")) + # ): + # for n, m in self.model.named_modules(): + # if type(m) in self.supported_types: + # if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0: + # self.layer_config[n] = {"bits": 16} + # logger.info( + # f"{n} will not be quantized due to its shape not being divisible by 32," + # " resulting in an exporting issue to autogptq" + # ) if ( self.seqlen is not None @@ -802,19 +853,26 @@ def remove_duplicates(lst): formats = format.replace("q*_", f"q{self.bits}_").replace(" ", "").split(",") formats = remove_duplicates(formats) # need the keep origin order - if isinstance(self.scheme, str) and self.scheme.lower().startswith("gguf"): + gguf_format_name = get_gguf_scheme(self.scheme) + + if gguf_format_name: for i in range(len(formats)): - if formats[i] != "fake" and formats[i] != self.scheme.lower(): + if formats[i] != "fake" and formats[i] != gguf_format_name.lower(): logger.warning( - f"reset format {formats[i]} to {self.scheme.lower()} " - f"since scheme {self.scheme} can only be exported to format {self.scheme.lower()}" + f"reset format {formats[i]} to {gguf_format_name.lower()} " + f"since scheme {gguf_format_name} can only be exported to format {gguf_format_name.lower()}" ) - formats[i] = self.scheme.lower() + formats[i] = gguf_format_name.lower() _gguf_args_check(self, formats, model_type=ModelType.TEXT) if self.mllm: _gguf_args_check(self, formats, model_type=ModelType.MMPROJ) + for f in formats: + if f.startswith("gguf"): + self.scheme = f.upper() + break + for format_ in formats: if format_ not in SUPPORTED_FORMATS: logger.error(f"Unsupported format {format_}, please choose from {SUPPORTED_FORMATS}") @@ -1298,92 +1356,6 @@ def get_imatrix_hook(module, input, output): for hook in hooks: hook.remove() - def _check_need_to_quantize_lm_head_embedding(self) -> bool: - """Checks if LM head and embedding layers need quantization for GGUF format. - - This function inspects the current model's formats and determines whether - it needs to apply quantization settings to the embedding and LM head layers. - The function modifies `self.layer_config` in-place and updates the model modules. - - Returns: - bool: True if the LM head needs quantization, otherwise False. - - Raises: - NotImplementedError: If multiple non-fake GGUF formats are specified. 
- """ - gguf_scheme = False - if isinstance(self.scheme, str) and "gguf" in self.scheme.lower(): - gguf_scheme = True - - if not hasattr(self, "formats") and not gguf_scheme: - return False - - has_gguf: bool = gguf_scheme or any("gguf" in fmt for fmt in self.formats) - if not has_gguf: - return False - if hasattr(self, "formats"): - formats: list[str] = [fmt for fmt in self.formats if "fake" not in fmt] - if not (len(formats) == 1 and "gguf" in formats[0]): - raise NotImplementedError("Only one GGUF format can be set at a time.") - target_format: str = formats[0] - - else: - target_format = self.scheme.lower() - - tie_word_embeddings: bool = getattr(getattr(self.model, "config", None), "tie_word_embeddings", True) - for name, module in self.model.named_modules(): - if isinstance(module, torch.nn.Embedding): - key: str = "lm_head" if tie_word_embeddings else "embedding" - config: dict[str, Any] = GGUF_INNER_CONFIG[GGUF_CONFIG[target_format][key]] - self._apply_config_to_layer(name, config, True) - - if not tie_word_embeddings: - lm_head_name: str = get_lm_head_name(self.model) - config: dict[str, Any] = GGUF_CONFIG[GGUF_CONFIG[target_format]["lm_head"]] - check_fixed_by_user = ( - self.layer_config[lm_head_name].get("fixed_by_user", False) - if lm_head_name in self.layer_config - else None - ) - self._apply_config_to_layer(lm_head_name, config, check_fixed_by_user=check_fixed_by_user) - return True - - return False - - def _apply_config_to_layer( - self, - layer_name: str, - config: dict[str, Any], - check_fixed_by_user: bool = False, - ) -> None: - """Applies GGUF quantization configuration to a given layer. - - Args: - layer_name (str): Name of the layer to configure. - config (dict[str, Any]): GGUF layer configuration. - check_fixed_by_user (bool): If True, preserve user-defined settings. - """ - act_bits: int = 16 - scale_dtype: Any = self.scale_dtype - keys: list[str] = ["bits", "group_size", "super_bits", "super_group_size", "data_type", "sym"] - - self.layer_config[layer_name] = self.layer_config.get(layer_name, {}) - - for key in keys: - if ( - key in self.layer_config[layer_name] - and check_fixed_by_user - # and self.layer_config[layer_name].get("fixed_by_user", False) - ): - continue - self.layer_config[layer_name][key] = config.get(key) - setattr(get_module(self.model, layer_name), key, config.get(key)) - - self.layer_config[layer_name]["act_bits"] = act_bits - self.layer_config[layer_name]["scale_dtype"] = scale_dtype - setattr(get_module(self.model, layer_name), "act_bits", act_bits) - setattr(get_module(self.model, layer_name), "scale_dtype", scale_dtype) - def _quantize_layer_via_rtn(self, name: str) -> None: """Quantizes a layer using RTN (Round-To-Nearest) if available. @@ -1648,9 +1620,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) if self.device_map is not None: accelerate.hooks.remove_hook_from_submodules(block) - if ( - is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats) - ) or is_static_wfp8afp8(self): + if (is_nv_fp(self.act_data_type)) or is_static_wfp8afp8(self): # enable moe experts act_max automatic generation for Linear set_amax_for_all_moe_layers(block, attr_name="act_max") # Normalize imatrix and quantize layers @@ -1696,32 +1666,55 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: Returns: The quantized model and layer configurations. 
""" - for n, m in self.model.named_modules(): + for n, m in self.model.named_modules(): # TODO check if could removed m.tmp_name = n self._check_compatibility() formats = self.formats if hasattr(self, "formats") else None # It is best to modify the model structure in the quantize function and check the format, # because it may cause the gguf format to not be exported normally. self.model = _handle_moe_model(self.model, formats=formats) - self.has_qlayer_outside_block = self._set_layerwise_config(self.layer_config) + + # TODO check scale_dtype + if not self.is_auto_scheme: + self.layer_config, self.has_qlayer_outside_block = set_layer_config( + self.model, + self.layer_config, + self.scheme, + self.scale_dtype, + self.supported_types, + self.inner_supported_types, + self.quant_block_list, + self.fp_layers, + self.quant_lm_head, + enable_gguf_official_mixed=True, + is_mllm=self.mllm, + ) + else: + # for n, scheme in self.layer_config.items(): + # module = get_module(self.model, n) + # if not isinstance(scheme, dict): + # raise ValueError("scheme return by scheme should be dict") + # for key, item in scheme.items(): + # setattr(module, key, item) + # # set_extra scale_dtype + # module.scale_dtype = self.scale_dtype + self.layer_config, self.has_qlayer_outside_block = set_layer_config( + self.model, + self.layer_config, + self.scheme, + self.scale_dtype, + self.supported_types, + self.inner_supported_types, + self.quant_block_list, + self.fp_layers, + self.quant_lm_head, + enable_gguf_official_mixed=False, + is_mllm=self.mllm, + ) + if not hasattr(self, "formats"): logger.warning("this API is deprecated, please use `quantize_and_save` instead") else: - only_gguf = True - for format_ in self.formats: - if not ("gguf" in format_ or "fake" in format_): - only_gguf = False - break - if len(self.formats) == 1 and self.formats[0] == "fake": - only_gguf = False - if only_gguf: - self.layer_config, gguf_format_config = get_layer_config_by_gguf_format( - self.layer_config, self.formats, self.model, model_type=ModelType.TEXT - ) - if self.mllm: - self.layer_config, gguf_format_config = get_layer_config_by_gguf_format( - self.layer_config, self.formats, self.model, model_type=ModelType.MMPROJ - ) # Determine if immediate packing is required formats = self.formats if ( @@ -1827,7 +1820,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: cost_time = end_time - self.start_time logger.info(f"quantization tuning time {cost_time}") - ## dump a summary + # Dump a summary quantized_layers = [] unquantized_layers = [] for n, m in self.model.named_modules(): @@ -1924,141 +1917,6 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: del layer_input clear_memory(q_layer_input) - def _set_layerwise_config(self, layer_config: dict) -> bool: - """ - Sets the layer-wise configuration based on the provided `layer_config`. - By default, only quantize layers in blocks. - - Args: - layer_config (dict): The configuration dictionary for each layer containing various configuration options. - - Returns: - bool: Returns True if there are quantized layers outside the blocks (e.g., lm-head), - otherwise returns False. 
- """ - # Get the names of layers in quantization blocks - supported_types = self.supported_types - layers_in_blocks = get_layer_names_in_block( - self.model, supported_types, self.quant_block_list, self.inner_supported_types - ) - ##process regex in layer_config - all_supported_layer_names = [] - # List of configuration keys - keys = get_quant_keys() - - for n, m in self.model.named_modules(): - # Delete previous configuration to avoid conflicts with prior tuning - for key in keys: - if hasattr(m, key): - delattr(m, key) - - if not isinstance(m, supported_types) and m.__class__.__name__ not in self.inner_supported_types: - continue - all_supported_layer_names.append(n) - - names_in_layer_config = list(layer_config.keys()) - for name in names_in_layer_config: - if name in all_supported_layer_names: - continue - matched_names = [] - for layer_name in all_supported_layer_names: - if re.search(re.compile(name), layer_name) is not None: - matched_names.append(layer_name) - if len(matched_names) > 0: - val = layer_config[name] - layer_config.pop(name) - for match_name in matched_names: - layer_config[match_name] = val - else: - tmp_m = get_module(self.model, name) - if not isinstance(tmp_m, torch.nn.Embedding): # TODO not good code style - raise ValueError(f"key {name} in layer_config is invalid, please have a double check") - - has_qlayer_outside_block = False # Flag to track if there are quantized layers outside blocks (e.g., lm-head) - - # Iterate through all modules in the model - is_gguf = hasattr(self, "formats") and any("gguf" in format_ for format_ in self.formats) - for n, m in self.model.named_modules(): - # Skip unsupported types - if type(m) not in supported_types and m.__class__.__name__ not in self.inner_supported_types: - if n in self.layer_config: - if not isinstance(m, torch.nn.Embedding): - logger.warning(f"{n} is not supported, layer_config {n}: {layer_config[n]} will be ignored.") - self.layer_config.pop(n) - continue - if not is_gguf: - if not check_to_quantized(layer_config[n]): - self.layer_config.pop(n) - continue - else: - continue - - # If the layer is not in the config and is part of a quantization block, use default configuration - if n not in layer_config.keys() and n in layers_in_blocks: - layer_config[n] = {} - for key in keys: - layer_config[n][key] = getattr(self, key) - - # If the layer is partially configured, fill in missing values - elif n in layer_config.keys(): - if "data_type" in layer_config[n] and "bits" not in layer_config[n]: - tmp_bits = infer_bits_by_data_type(layer_config[n]["data_type"]) - if tmp_bits is not None and tmp_bits != self.bits: - logger.warning( - f"'data_type' do not match the specified 'bits' setting for {n}." - f" Resetting 'bits' to {tmp_bits}." - ) - layer_config[n]["bits"] = tmp_bits - if "act_data_type" in layer_config[n] and "act_bits" not in layer_config[n]: - tmp_bits = infer_bits_by_data_type(layer_config[n]["act_data_type"]) - if tmp_bits is not None and tmp_bits != self.act_bits: - logger.warning( - f"'act_data_type' do not match the specified 'act_bits' setting for {n}." - f" Resetting 'act_bits' to {tmp_bits}." 
- ) - layer_config[n]["act_bits"] = tmp_bits - - for key in keys: - if key not in layer_config[n].keys(): - layer_config[n][key] = getattr(self, key) - layer_config[n]["fixed_by_user"] = True - - # If the layer is not in the config and not part of a quantization block, - # use default configuration and set specific values - else: - layer_config[n] = {} - for key in keys: - layer_config[n][key] = getattr(self, key) - layer_config[n]["bits"] = 16 - layer_config[n]["act_bits"] = 16 - - if n in layers_in_blocks: - layer_config[n]["in_blocks"] = True - else: - layer_config[n]["in_blocks"] = False - - # If the layer is outside a block and requires quantization, mark it as a quantized layer outside the block - if ( - n not in layers_in_blocks - and check_to_quantized(layer_config[n]) - and not isinstance(m, torch.nn.Embedding) - ): - has_qlayer_outside_block = True - - in_features, out_features = get_layer_features(m) - if in_features <= layer_config[n]["group_size"]: - layer_config[n]["group_size"] = -1 - - # Apply the configuration to the corresponding layer in the model - for key in keys: - setattr(m, key, layer_config[n][key]) - need_to_quantize_lm_head = self._check_need_to_quantize_lm_head_embedding() - if need_to_quantize_lm_head: - has_qlayer_outside_block = True - - # Return whether there are quantized layers outside the blocks - return has_qlayer_outside_block - @torch.no_grad() def _get_block_outputs( self, @@ -2250,8 +2108,12 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l if str(self.model.device) == "cpu" and ( self.device.startswith("xpu") or self.device.startswith("cuda") ): - max_memory = get_max_vram() # TODO model is not evenly split no_split_modules = getattr(self.model, "_no_split_modules", []) + max_memory = get_balanced_memory( + self.model, + max_memory=None, + no_split_module_classes=no_split_modules, + ) device_map = infer_auto_device_map( self.model, max_memory=max_memory, no_split_module_classes=no_split_modules ) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index 6456b817f..6aa19a3d5 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -337,7 +337,7 @@ def quant_tensor_gguf_asym_dq( if bits == 2: quant_weights = torch.abs(tensor) elif bits == 4 or bits == 5: - sigma2 = torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / 32 ##Note 32 is different from QK_K + sigma2 = torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / 32 # Note 32 is different from QK_K av_x = torch.sqrt(sigma2) quant_weights = torch.abs(tensor) + av_x params = search_kwargs[bits] diff --git a/auto_round/data_type/register.py b/auto_round/data_type/register.py index 12c4406a4..fca259ed6 100644 --- a/auto_round/data_type/register.py +++ b/auto_round/data_type/register.py @@ -22,8 +22,7 @@ def register_dtype(names): Decorator function used before a Pattern subclass. Args: - cls (class): The subclass of register. - name: A string. Define the export type. + names: A string. Define the export type. Returns: cls: The class of register. 
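# [Editor's note -- illustrative, not part of this patch] The schemes.py hunk further below
# relaxes QuantizationScheme.__eq__ so that act_bits=None and act_bits=16 are treated as the
# same "activations not quantized" case; _parse_and_set_scheme() relies on this to keep the
# original preset name (important for GGUF presets) when the resolved fields still match it.
# A small self-contained illustration with hypothetical field values:

from auto_round.schemes import QuantizationScheme

a = QuantizationScheme.from_dict({"bits": 4, "group_size": 32, "act_bits": None})
b = QuantizationScheme.from_dict({"bits": 4, "group_size": 32, "act_bits": 16, "act_sym": True})
assert a == b  # act_* fields are skipped once both sides effectively use >= 16 activation bits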
diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py index c4a02f673..eaf3ad9ae 100644 --- a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py +++ b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py @@ -174,7 +174,7 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs): for n, m in model.named_modules(): if type(m) in SUPPORTED_LAYER_TYPES: layer = m - if layer.act_bits < 8 and not getattr(layer, "input_global_scale", None): + if hasattr(layer, "act_bits") and layer.act_bits < 8 and not getattr(layer, "input_global_scale", None): assert hasattr(layer, "act_max") from auto_round.data_type.nvfp import calculate_gparam @@ -198,7 +198,7 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs): for layer_name in layer_config: if ( not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8 - ): ##lm head ##TODO fix act and so on + ): ##lm head # TODO fix act and so on extra_config[layer_name] = {} extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"] extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"] diff --git a/auto_round/export/export_to_llmcompressor/export_to_fp.py b/auto_round/export/export_to_llmcompressor/export_to_fp.py index 1a56ebf12..633593075 100644 --- a/auto_round/export/export_to_llmcompressor/export_to_fp.py +++ b/auto_round/export/export_to_llmcompressor/export_to_fp.py @@ -169,7 +169,7 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs): for n, m in model.named_modules(): if type(m) in SUPPORTED_LAYER_TYPES: layer = m - if layer.act_bits < 8 and not getattr(layer, "input_global_scale", None): + if hasattr(layer, "act_bits") and layer.act_bits < 8 and not getattr(layer, "input_global_scale", None): assert hasattr(layer, "act_max") from auto_round.data_type.nvfp import calculate_gparam diff --git a/auto_round/schemes.py b/auto_round/schemes.py index a5c5975c9..789cbe1cb 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -14,9 +14,9 @@ import copy from copy import deepcopy from dataclasses import dataclass, fields -from typing import Generator, List, Optional +from typing import Iterable, Optional, Union -__all__ = ["QuantizationScheme", "preset_name_to_scheme"] +__all__ = ["QuantizationScheme", "get_gguf_scheme", "preset_name_to_scheme", "AutoScheme"] @dataclass @@ -38,7 +38,7 @@ def from_dict(cls, config: dict): return cls(**config) @classmethod - def get_attributes(cls: "QuantizationScheme") -> List[str]: + def get_attributes(cls: "QuantizationScheme") -> list[str]: return [field.name for field in fields(cls)] def __getitem__(self, key: str): @@ -72,7 +72,15 @@ def get(self, key: str, default=None): def __eq__(self, other: "QuantizationScheme") -> bool: if not isinstance(other, QuantizationScheme): return False + skip_act_check = False + self_act_bits = 16 if self.act_bits is None else self.act_bits + other_act_bits = 16 if other.act_bits is None else other.act_bits + if self_act_bits == other_act_bits and other_act_bits >= 16: + skip_act_check = True + for field in self.get_attributes(): + if skip_act_check and field.startswith("act_"): + continue if getattr(self, field) != getattr(other, field): return False return True @@ -180,6 +188,7 @@ def is_preset_scheme(name: str) -> bool: } ) + # FP8 = asdict(QuantArgs.from_dict({ # "bits": 8, # "group_size": 128, @@ -201,6 +210,18 @@ def is_preset_scheme(name: str) -> bool: } ) +# For AutoScheme 16 bits options +BF16 = 
QuantizationScheme.from_dict( + { + "bits": 16, + "group_size": 128, + "data_type": "fp", + "act_bits": 16, + "act_data_type": "fp", + } +) + + PRESET_SCHEMES = { "W4A16": W4A16, "W2A16": W2A16, @@ -211,6 +232,7 @@ def is_preset_scheme(name: str) -> bool: "NVFP4": NVFP4, "FPW8A16": FPW8A16, "FP8_STATIC": FP8_STATIC, + "BF16": BF16, } from auto_round.export.export_to_gguf.config import GGUF_CONFIG @@ -220,3 +242,39 @@ def is_preset_scheme(name: str) -> bool: value.pop("embedding", None) value.pop("lm_head", None) PRESET_SCHEMES[key.upper()] = QuantizationScheme.from_dict(value) + + +def get_gguf_scheme(scheme: Union[str, QuantizationScheme]) -> str: + if isinstance(scheme, str) and scheme.upper().startswith("GGUF"): + return scheme + if isinstance(scheme, str): + return "" + for key, val in PRESET_SCHEMES.items(): + # For q40 or q4_1 we only support it with str scheme, otherwise it will be matched incorrectly with W4G32 + if not key.upper().startswith("GGUF") or ("0" in key or "1" in key): + continue + equal = True + for scheme_key in val.keys(): + if val[scheme_key] is not None and val[scheme_key] != scheme.get(scheme_key, None): + equal = False + break + if equal: + return key + return "" + + +@dataclass +class AutoScheme: + avg_bits: float + options: Union[str, list[Union[QuantizationScheme, str]], tuple[Union[QuantizationScheme, str], ...]] + shared_layers: Optional[Iterable[Iterable[str]]] = None + method: str = "default" + ignore_scale_zp_bits: bool = False + nsamples: Optional[int] = None + seqlen: Optional[int] = None + dataset: Optional[str] = None # Import Notice no comma for each item + + def __post_init__(self): + if isinstance(self.options, str): + options = self.options.upper().replace(" ", "") + self.options = options.split(",") diff --git a/auto_round/utils.py b/auto_round/utils.py index 90f161df7..37460f119 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -21,6 +21,7 @@ import re import sys from collections import UserDict +from dataclasses import asdict, fields from enum import Enum from functools import lru_cache from pathlib import Path @@ -29,12 +30,13 @@ import cpuinfo import torch import transformers +from accelerate.utils import get_balanced_memory from packaging import version from torch.amp import autocast from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, GGUF_CONFIG, GGUF_INNER_CONFIG, QK_K, ModelType from auto_round.logger import logger -from auto_round.schemes import QuantizationScheme +from auto_round.schemes import QuantizationScheme, get_gguf_scheme, preset_name_to_scheme SHARED_CACHE_KEYS = ("position_ids", "cache_position", "position_embeddings") @@ -782,8 +784,11 @@ def check_memory_availability(device, inputs, weight, org_seqlen, org_bs): def get_layer_names_in_block( - model, supported_types=(torch.nn.Linear, transformers.pytorch_utils.Conv1D), quant_block_list=None, class_names=None -): + model: torch.nn.Module, + supported_types=(torch.nn.Linear, transformers.pytorch_utils.Conv1D), + quant_block_list: list = None, + class_names: tuple = None, +) -> list[str]: """Retrieves the names of layers within each block of the model. 
Returns: @@ -807,7 +812,6 @@ def get_layer_names_in_block( if hasattr(m, "bk_tmp_name"): layers_in_block.append(m.bk_tmp_name) delattr(m, "bk_tmp_name") - return layers_in_block @@ -1064,7 +1068,7 @@ def can_pack_with_numba(): # pragma: no cover return True -def get_fp_layer_names(model, fp_layers): +def get_fp_layer_names(model: torch.nn.Module, fp_layers: str): """Identifies and returns layers in the model to exclude from quantization. This function processes a comma-separated list of fully precision (FP) layers, @@ -1886,9 +1890,9 @@ def _gguf_type_fallback(gguf_type): ##https://github.com/ggml-org/llama.cpp/blob/9e31bec4fd53634c9e5b04650488a09a055f5dab/src/llama-quant.cpp#L129 -def get_layer_config_by_gguf_format(layer_config, gguf_format, model, model_type=ModelType.TEXT): - # TODO: support for other format later - target_gguf_format = next((fmt for fmt in gguf_format if fmt != "fake"), None) +def get_layer_config_by_gguf_format(layer_config, target_gguf_format: str, model, model_type=ModelType.TEXT): + # TODO: support other formats later + # target_gguf_format = next((fmt for fmt in gguf_format if fmt != "fake"), None) import gguf # pylint: disable=E0401 @@ -1982,6 +1986,34 @@ def _set_config(config, target_config): ) new_type = new_type[:bits_index] + target_bits + new_type[bits_index + 1 :] else: + config_tmp = config.copy() + scheme_keys = [f.name for f in fields(QuantizationScheme)] + for key in config.keys(): + if key not in scheme_keys: + config_tmp.pop(key, None) + matched_scheme = get_gguf_scheme(QuantizationScheme.from_dict(config_tmp)) # check whether config matches a preset scheme + if not matched_scheme: + if config.get("super_group_size", None) is not None: + new_type = new_type[:bits_index] + str(config["bits"]) + "_k" + if config.get("super_group_size", None) is None or new_type not in GGUF_INNER_CONFIG: + if config.get("sym", True): + new_type = new_type[:bits_index] + str(config["bits"]) + "_0" + if new_type not in GGUF_INNER_CONFIG: + new_type = new_type[:bits_index] + str(config["bits"]) + "_1" + if not config.get("sym", True): + new_type = new_type[:bits_index] + str(config["bits"]) + "_1" + if new_type not in GGUF_INNER_CONFIG: + new_type = new_type[:bits_index] + str(config["bits"]) + "_0" + if new_type not in GGUF_INNER_CONFIG: + raise ValueError( + f"the layer_config setting for {layer_name} " + f"does not match any supported gguf format, please check it."
+ ) + else: + logger.warning_once( + f"the layer_config setting for {layer_name} " + f"does not match any supported gguf format, resetting to {new_type}" + ) new_type = new_type[:bits_index] + str(config["bits"]) + new_type[bits_index + 1 :] new_type = _search_gguf_type(new_type) if new_type is None: @@ -2233,9 +2265,9 @@ def get_reciprocal(tensor): def check_need_act_calibration( - is_act_dynamic: Union[bool, None], act_data_type: Union[str, None] = None, act_bits: int = 16 + is_act_dynamic: Union[bool, None], act_data_type: Union[str, None] = None, act_bits: Union[int, None] = 16 ) -> bool: - if act_bits > 8: + if act_bits is None or act_bits > 8: return False # None is dynamic if is_act_dynamic is not None and not is_act_dynamic: @@ -2325,8 +2357,8 @@ def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16): new_layer = torch.nn.Linear(layer.in_features, layer.out_features, bias=layer.bias is not None, dtype=dtype) if layer.bias is not None: new_layer.bias.data.copy_(layer.bias.data.to(dtype=dtype)) - - keys = get_quant_keys() + ["tmp_name"] + scheme_keys = (f.name for f in fields(QuantizationScheme)) + keys = tuple(scheme_keys) + ("tmp_name", "scale_dtype") for key in keys: setattr(new_layer, key, getattr(layer, key, None)) @@ -2355,24 +2387,6 @@ def convert_fp8_model_to_16b_model(model, dtype=torch.bfloat16): return model -def get_quant_keys(): - keys = [ - "bits", - "group_size", - "sym", - "data_type", - "scale_dtype", - "act_bits", - "act_group_size", - "act_sym", - "act_dynamic", - "act_data_type", - "super_bits", - "super_group_size", - ] - return keys - - def out_of_vram(error_msg): error_msg = str(error_msg) # CUDA @@ -2805,6 +2819,199 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module]): return False +def set_layer_config( + model: torch.nn.Module, + layer_config: dict[str, Union[str, dict, "QuantizationScheme"]], + default_scheme: Union[str, "QuantizationScheme"], + default_scale_dtype: torch.dtype | str, + supported_types: tuple, + inner_supported_types: tuple, + quant_block_list=None, + fp_layers: str = "", + quant_lm_head: bool = False, + enable_gguf_official_mixed: bool = True, + is_mllm: bool = False, +) -> tuple[dict, bool]: + """ + Normalize, validate, and expand layer-specific quantization configs. + Returns (final_layer_config, has_quant_layer_outside_block) + """ + + from auto_round.schemes import get_gguf_scheme + + # ---- helpers ------------------------------------------------- + def dispatch_layer_config(layer_config: dict[str, dict]) -> None: + """Assign scheme values as attributes to matched modules.""" + for layer_name, scheme in layer_config.items(): + module = get_module(model, layer_name) + for attr, value in scheme.items(): + setattr(module, attr, value) + + def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str) -> dict: + """Convert config entry into dict and validate keys.""" + if isinstance(item, str): + config = asdict(preset_name_to_scheme(item.upper())) + elif isinstance(item, QuantizationScheme): + config = asdict(item) + elif isinstance(item, dict): + invalid = set(item) - set(scheme_keys) + if invalid: + raise ValueError( + f"Invalid keys {invalid} in layer_config for '{layer_name}'. " f"Allowed keys: {scheme_keys}" + ) + config = dict(item) + else: + raise TypeError( + f"Unsupported type for layer_config[{layer_name}]: {type(item)}. " + f"Expected str, dict, or QuantizationScheme."
+ ) + # Clean up + config = {k: v for k, v in config.items() if v is not None} + config["fixed_by_user"] = True + return config + + # ---- main logic ---------------------------------------------- + scheme_keys = tuple(f.name for f in fields(QuantizationScheme)) + ("scale_dtype",) + layer_config = copy.deepcopy(layer_config) or {} + + # 1. fp_layers -> force 16 + for name in get_fp_layer_names(model, fp_layers): + layer_config[name] = { + "bits": 16, + "act_bits": 16, + "data_type": "float", + "act_data_type": "float", + "fixed_by_user": True, + } + + # 2. normalize + layer_config = {k: normalize_item(v, k) for k, v in layer_config.items()} + + # 3. infer missing bits + for cfg in layer_config.values(): + if "data_type" in cfg and "bits" not in cfg: + if (b := infer_bits_by_data_type(cfg["data_type"])) is not None: + cfg["bits"] = b + if "act_data_type" in cfg and "act_bits" not in cfg: + if (b := infer_bits_by_data_type(cfg["act_data_type"])) is not None: + cfg["act_bits"] = b + + # 4. fill defaults + if isinstance(default_scheme, str): + default_dict = asdict(preset_name_to_scheme(default_scheme.upper())) + else: + default_dict = asdict(default_scheme) + default_dict["scale_dtype"] = default_scale_dtype + for cfg in layer_config.values(): + for key in scheme_keys: + cfg.setdefault(key, copy.deepcopy(default_dict.get(key))) + + # 5. collect supported modules + gguf_name = get_gguf_scheme(default_scheme) + if gguf_name and torch.nn.Embedding not in supported_types: + supported_types = (*supported_types, torch.nn.Embedding) + + all_supported_layer_names, embedding_layer_names = [], [] + all_module_names = [] + for n, m in model.named_modules(): + all_module_names.append(n) + # cleanup stale attributes + for key in scheme_keys: + if hasattr(m, key): + delattr(m, key) + if type(m) not in supported_types and m.__class__.__name__ not in inner_supported_types: + continue + all_supported_layer_names.append(n) + if isinstance(m, torch.nn.Embedding): + embedding_layer_names.append(n) + + # 6. expand regex configs + for name in list(layer_config.keys()): + if name in all_supported_layer_names: + continue + if name in all_module_names: + m = get_module(model, name) + if len(list(m.children())) == 0 and type(m) not in supported_types: + logger.warning(f"{name} is not supported by the current scheme, ignoring its setting in `layer_config`") + continue + + regex = re.compile(name) + matched = [ln for ln in all_supported_layer_names if regex.search(ln)] + if not matched: + raise ValueError(f"Invalid '{name}' in layer_config, no match found.") + val = layer_config.pop(name) + for match in matched: + layer_config[match] = val + + # 7. lm_head + lm_head_name = get_lm_head_name(model) + tie_word_embeddings = False + if hasattr(model, "config") and hasattr(model.config, "tie_word_embeddings"): + tie_word_embeddings = model.config.tie_word_embeddings + + if quant_lm_head and tie_word_embeddings: + quant_lm_head = False + logger.warning( + "reset `quant_lm_head` to False as quantizing " "lm_head with tied weights is not currently supported" + ) + + if lm_head_name not in layer_config and quant_lm_head: + layer_config[lm_head_name] = copy.deepcopy(default_dict) + + # 8.
enforce shape divisibility for int weight-only + if default_dict["data_type"] == "int" and default_dict["act_bits"] >= 16 and not gguf_name: + for n, m in model.named_modules(): + if type(m) in supported_types or m.__class__.__name__ in inner_supported_types: + if m.weight.shape[0] % 32 or m.weight.shape[1] % 32: + layer_config.setdefault(n, copy.deepcopy(default_dict)) + layer_config[n].update({"bits": 16, "data_type": "fp", "fixed_by_user": True}) + logger.warning_once(f"{n} skipped quantization (shape not divisible by 32).") + + # 9. block layers: mark as in_blocks=True + for name in get_layer_names_in_block(model, supported_types, quant_block_list, inner_supported_types): + if name not in layer_config: + layer_config[name] = copy.deepcopy(default_dict) + layer_config[name]["fixed_by_user"] = False + layer_config[name]["in_blocks"] = True + + # ---- restore: ensure missing in_blocks are set to False and compute flag ---- + has_qlayer_outside_block = False + for cfg in layer_config.values(): + if "in_blocks" not in cfg: + cfg["in_blocks"] = False + # mark layer outside block + if not cfg["in_blocks"] and check_to_quantized(cfg): + has_qlayer_outside_block = True + + # 10. GGUF handling + if not gguf_name: + dispatch_layer_config(layer_config) + return layer_config, has_qlayer_outside_block + + # embed + lm_head defaults for gguf + if lm_head_name not in layer_config and not tie_word_embeddings: + cfg = GGUF_INNER_CONFIG[GGUF_CONFIG[gguf_name.lower()]["lm_head"]] + cfg = {**cfg, "fixed_by_user": False, "scale_dtype": default_scale_dtype} + layer_config[lm_head_name] = cfg + has_qlayer_outside_block = True + for emd_name in embedding_layer_names: + if emd_name in layer_config: + continue + if not tie_word_embeddings: + cfg = GGUF_INNER_CONFIG[GGUF_CONFIG[gguf_name.lower()]["embedding"]] + else: + cfg = GGUF_INNER_CONFIG[GGUF_CONFIG[gguf_name.lower()]["lm_head"]] + cfg = {**cfg, "fixed_by_user": False, "scale_dtype": default_scale_dtype} + layer_config[emd_name] = cfg + + if enable_gguf_official_mixed: + model_type = ModelType.MMPROJ if is_mllm else ModelType.TEXT + layer_config, _ = get_layer_config_by_gguf_format(layer_config, gguf_name.lower(), model, model_type) + + dispatch_layer_config(layer_config) + return layer_config, has_qlayer_outside_block + + def check_diffusers_installed(): # pragma: no cover try: import diffusers # noqa: F401 @@ -2815,7 +3022,7 @@ def check_diffusers_installed(): # pragma: no cover exit(-1) -def is_diffusion_model(model_or_path: Union[str, object]): +def is_diffusion_model(model_or_path: Union[str, object]) -> bool: if isinstance(model_or_path, str): index_file = None if not os.path.isdir(model_or_path): diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index f6bffa94d..0a90c3965 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -150,7 +150,7 @@ def _init_tuning_params_and_quant_func(self): ) self._init_params("act_max_scale", p_dtype, (1), 1.0, not orig_layer.act_dynamic) - ## bias tuning + # Bias tuning if self.enable_norm_bias_tuning: self._init_params("bias_v", p_dtype, self.orig_layer.bias.shape, 0, True) from auto_round.data_type.int import quant_tensor_asym_wo_round diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index de5f7412a..cbd0583df 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -736,6 +736,7 @@ def test_invalid_layer_config(self): iters=1, layer_config=layer_config, ) + ar.quantize() def test_quant_lm_head(self): model_name = 
"/tf_dataset/auto_round/models/Qwen/Qwen3-8B" diff --git a/test/test_cuda/test_auto_scheme.py b/test/test_cuda/test_auto_scheme.py new file mode 100644 index 000000000..e11ce1a01 --- /dev/null +++ b/test/test_cuda/test_auto_scheme.py @@ -0,0 +1,184 @@ +import copy +import re +import shutil +import sys +import unittest + +from auto_round.testing_utils import multi_card + +sys.path.insert(0, "../..") + +from auto_round import AutoRound, AutoRoundConfig, AutoScheme +from auto_round.auto_schemes.utils import compute_avg_bits_for_model +from auto_round.eval.evaluation import simple_evaluate +from auto_round.utils import get_module + + +class TestAutoScheme(unittest.TestCase): + @classmethod + def setUpClass(self): + self.save_dir = "./saved" + self.tasks = "lambada_openai" + + @classmethod + def tearDownClass(self): + shutil.rmtree("./saved", ignore_errors=True) + shutil.rmtree("runs", ignore_errors=True) + + def test_gguf_export(self): + model_name = "/models/Qwen3-0.6B" + target_bits = 3 + scheme = AutoScheme(avg_bits=target_bits, options=("GGUF:Q2_K_S", "GGUF:Q4_K_M"), ignore_scale_zp_bits=True) + ar = AutoRound(model=model_name, scheme=scheme, iters=0) + ar.quantize_and_save(self.save_dir, format="gguf:q2_k_s") + shutil.rmtree("./saved", ignore_errors=True) + + def test_gguf(self): + model_name = "/models/Qwen3-8B" + target_bits = 3 + scheme = AutoScheme(avg_bits=target_bits, options=("GGUF:Q2_K_S", "GGUF:Q4_K_M"), ignore_scale_zp_bits=True) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1, disable_opt_rtn=True) + model, layer_config = ar.quantize() + # self.assertLessEqual(layer_config["lm_head"]["bits"], 8) + avg_bits, _ = compute_avg_bits_for_model(model, ignore_scale_zp_bits=True) + print(avg_bits) + assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3 + + def test_shared_layers(self): + model_name = "/models/opt-125m" + from transformers import AutoModelForCausalLM, AutoTokenizer + + model = AutoModelForCausalLM.from_pretrained(model_name) + shared_layers = [ + ["*.self_attn.k_proj", "v_proj", "q_proj", "out_proj"], + ("model.decoder.layers.6.fc1", "model.decoder.layers.6.fc2"), + ("fc1", "fc2"), + ] + from auto_round.auto_schemes.utils import parse_shared_layers + + res = parse_shared_layers(model, shared_layers) + self.assertEqual(len(res), 24) + assert [ + "model.decoder.layers.2.self_attn.out_proj", + "model.decoder.layers.2.self_attn.q_proj", + "model.decoder.layers.2.self_attn.v_proj", + ] in res + assert ["model.decoder.layers.6.fc1", "model.decoder.layers.6.fc2"] in res + assert ["model.decoder.layers.7.fc1", "model.decoder.layers.7.fc2"] in res + target_bits = 5.0 + scheme = AutoScheme(avg_bits=target_bits, options=("W4A16", "MXFP8"), shared_layers=shared_layers) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1) + model, layer_config = ar.quantize() + avg_bits, _ = compute_avg_bits_for_model(model) + for names in res: + bits = [] + for name in names: + module = get_module(model, name) + if hasattr(module, "orig_layer"): + bits.append(module.orig_layer.bits) + else: + bits.append(module.bits) + bits = set(bits) + self.assertEqual(len(bits), 1) + print(avg_bits) + assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3 + + @multi_card + def test_multi_card(self): + model_name = "/models/Qwen3-8B" + target_bits = 5.254 + for device_map in ["auto", "0,1", "0", None]: + scheme = AutoScheme(avg_bits=target_bits, options=("NVFP4")) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1, device_map=device_map) + model, 
layer_config = ar.quantize() + avg_bits, _ = compute_avg_bits_for_model(model) + print(avg_bits) + assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3 + + @multi_card + def test_dict_device_map(self): # TODO rtn mode has bug + model_name = "/models/Qwen3-8B" + target_bits = 8.755 + device_map = {"up_proj": 0, "down_proj": 1} + + scheme = AutoScheme(avg_bits=target_bits, options=("MXFP8")) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1, device_map=device_map) + model, layer_config = ar.quantize() + avg_bits, _ = compute_avg_bits_for_model(model) + print(avg_bits) + assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3 + + def test_min_target_bits(self): + model_name = "/models/opt-125m" + target_bits = 4.644 + scheme = AutoScheme(avg_bits=target_bits, options=("MXFP4", "W8A16")) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1) + model, layer_config = ar.quantize() + # self.assertLessEqual(layer_config["lm_head"]["bits"], 8) + avg_bits, _ = compute_avg_bits_for_model(model) + print(avg_bits) + assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3 + + def test_max_target_bits(self): + model_name = "/models/opt-125m" + target_bits = 8.211 + scheme = AutoScheme(avg_bits=target_bits, options=("MXFP4", "W8A16")) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1) + model, layer_config = ar.quantize() + # self.assertLessEqual(layer_config["lm_head"]["bits"], 8) + avg_bits, _ = compute_avg_bits_for_model(model) + print(avg_bits) + assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3 + + def test_patch_scheme(self): + model_name = "/models/opt-125m" + target_bits = 5 + scheme = AutoScheme(avg_bits=target_bits, options=("MXFP4", "W8A16")) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1, group_size=32) + model, layer_config = ar.quantize() + for n, m in model.named_modules(): + if hasattr(m, "group_size"): + self.assertEqual(m.group_size, 32) + avg_bits, _ = compute_avg_bits_for_model(model) + print(avg_bits) + assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3 + + def test_layer_config(self): + target_bits = 3.0 + model_name = "/models/opt-125m" + scheme = AutoScheme(avg_bits=3, options=("W2A16", "W4A16", "BF16")) + user_layer_config = {"model.decoder.layers.10.fc1": {"bits": 8, "group_size": 32, "sym": False}} + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1, layer_config=user_layer_config) + model, layer_config = ar.quantize() + self.assertEqual(layer_config["model.decoder.layers.10.fc1"]["bits"], 8) + self.assertEqual(layer_config["model.decoder.layers.10.fc1"]["sym"], False) + self.assertEqual(layer_config["model.decoder.layers.10.fc1"]["group_size"], 32) + layer = get_module(model, "model.decoder.layers.10.fc1") + self.assertEqual(layer.bits, 8) + self.assertEqual(layer.sym, False) + self.assertEqual(layer.group_size, 32) + avg_bits, _ = compute_avg_bits_for_model(model) + print(avg_bits) + assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3 + + def test_lm_head_and_mix_dtype(self): + model_name = "/models/Qwen3-8B" + target_bits = 6 + scheme = AutoScheme(avg_bits=target_bits, options=("MXFP4", "W8A16")) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1, quant_lm_head=True) + model, layer_config = ar.quantize() + self.assertLessEqual(layer_config["lm_head"]["bits"], 8) + avg_bits, _ = compute_avg_bits_for_model(model) + print(avg_bits) + assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3 + + def test_auto_scheme_export(self): + 
model_name = "/models/opt-125m" + scheme = AutoScheme(avg_bits=3, options=("W2A16", "W4A16", "BF16")) + ar = AutoRound(model=model_name, scheme=scheme) + ar.quantize_and_save(self.save_dir) + model_args = f"pretrained={self.save_dir}" + result = simple_evaluate(model="hf", model_args=model_args, tasks="lambada_openai", batch_size="auto") + print(result["results"]["lambada_openai"]["acc,none"]) + self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.25) + shutil.rmtree(self.save_dir, ignore_errors=True)
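
Usage sketch (editor's note): the tests above exercise the new AutoScheme flow end to end. Below is a minimal, hedged example of the Python API introduced by this change; the model path is a placeholder, and the option names come from PRESET_SCHEMES plus the new BF16 entry. The same flow is reachable from the CLI through the new --avg_bits, --options, and --ignore_scale_zp_bits flags added in auto_round/__main__.py.

from auto_round import AutoRound, AutoScheme

# Search for a per-layer mixed-precision assignment averaging ~3 weight bits,
# choosing among the listed preset schemes for each layer.
scheme = AutoScheme(avg_bits=3.0, options=("W2A16", "W4A16", "BF16"), ignore_scale_zp_bits=False)

# "/models/opt-125m" is a placeholder path; iters=0 and nsamples=1 mirror the quick settings used in the tests.
ar = AutoRound(model="/models/opt-125m", scheme=scheme, iters=0, nsamples=1)
model, layer_config = ar.quantize()  # layer_config records the scheme picked for each layer
# ar.quantize_and_save("./saved")    # or quantize and export in one step, as in test_auto_scheme_export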