Changes from all commits (64 commits)
6ffcf60
try to enable auto_scheme API
wenhuach21 Sep 25, 2025
5d80825
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 25, 2025
a4ef495
update a little
wenhuach21 Sep 25, 2025
4173c3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 25, 2025
87e9454
update a little
wenhuach21 Sep 25, 2025
f86eedb
Merge branch 'main' into auto_scheme
wenhuach21 Sep 25, 2025
242d1ee
try to refine parse layer config code
wenhuach21 Sep 25, 2025
4fc6b64
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 25, 2025
63de904
Merge branch 'main' into auto_scheme
wenhuach21 Sep 26, 2025
bb4d4ca
Merge branch 'main' into auto_scheme
wenhuach21 Sep 26, 2025
7f76db2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2025
ae8837b
fix
wenhuach21 Sep 26, 2025
44ca92d
Merge branch 'auto_scheme' of https://github.com/intel/auto-round int…
wenhuach21 Sep 26, 2025
531224d
fix
wenhuach21 Sep 26, 2025
c9fa408
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2025
6453200
fix
wenhuach21 Sep 26, 2025
5b2dd60
Merge branch 'auto_scheme' of https://github.com/intel/auto-round int…
wenhuach21 Sep 26, 2025
3811010
tmp_change
wenhuach21 Sep 26, 2025
4de7b08
commit
wenhuach21 Sep 26, 2025
a9f0e44
commit
wenhuach21 Sep 26, 2025
59a9f5d
update a little
wenhuach21 Sep 26, 2025
1b7e911
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2025
e068049
fix
wenhuach21 Sep 26, 2025
1b84bf2
Merge branch 'auto_scheme' of https://github.com/intel/auto-round int…
wenhuach21 Sep 26, 2025
0357c0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2025
7c034bd
Merge branch 'main' into auto_scheme
wenhuach21 Sep 26, 2025
602421c
merge autoscheme to scheme
wenhuach21 Sep 26, 2025
091c5ad
refine layer_config code
wenhuach21 Sep 29, 2025
90b6fa1
Merge branch 'main' into auto_scheme
wenhuach21 Sep 29, 2025
f027801
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2025
c6b78c6
tiny change
wenhuach21 Sep 29, 2025
1b9f24e
tiny fix
wenhuach21 Sep 29, 2025
2c0075a
tmp change
wenhuach21 Sep 29, 2025
97198f0
tmp change
wenhuach21 Sep 29, 2025
27b4b4d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2025
2d3095a
update
wenhuach21 Sep 29, 2025
35a298b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2025
4a594cd
fix
wenhuach21 Sep 29, 2025
dcd08d6
fix uts, still one left
wenhuach21 Sep 30, 2025
9172264
fix gguf issue
wenhuach21 Sep 30, 2025
1d9e593
Merge branch 'main' into auto_scheme
wenhuach21 Sep 30, 2025
f98092c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2025
033d1f6
update a little
wenhuach21 Sep 30, 2025
8ae1dfa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2025
a3756ce
fix some issues
wenhuach21 Oct 9, 2025
2f93471
fix some issues
wenhuach21 Oct 9, 2025
e0c3d4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2025
0130932
Merge branch 'main' into auto_scheme
wenhuach21 Oct 9, 2025
6e04d10
update
wenhuach21 Oct 9, 2025
04c604c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2025
3880038
Merge branch 'main' into auto_scheme
wenhuach21 Oct 9, 2025
87d3694
fix one bug
wenhuach21 Oct 9, 2025
fa85d42
Merge branch 'main' into auto_scheme
wenhuach21 Oct 9, 2025
3855c8f
fix
wenhuach21 Oct 10, 2025
d3e28c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 10, 2025
706df03
Merge branch 'main' into auto_scheme
wenhuach21 Oct 10, 2025
2d557d0
set up the first version, there are many details to be handled
wenhuach21 Oct 10, 2025
567ebb8
Merge branch 'auto_scheme' of https://github.com/intel/auto-round int…
wenhuach21 Oct 10, 2025
cedad47
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 10, 2025
0c3a0e2
fix one bug
wenhuach21 Oct 10, 2025
cced6d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 10, 2025
58d5ae2
uncomment ut
wenhuach21 Oct 10, 2025
e9bcd4a
Merge branch 'auto_scheme' of https://github.com/intel/auto-round int…
wenhuach21 Oct 10, 2025
ea489c3
rename functions
wenhuach21 Oct 10, 2025
3 changes: 1 addition & 2 deletions auto_round/__init__.py
@@ -13,9 +13,8 @@
# limitations under the License.
from auto_round.autoround import AutoRound

# support for old api
from auto_round.autoround import AutoRoundLLM, AutoRoundMLLM, AutoRoundAdam
from auto_round.schemes import QuantizationScheme
from auto_round.schemes import QuantizationScheme, AutoScheme
from auto_round.utils import LazyImport


11 changes: 10 additions & 1 deletion auto_round/__main__.py
@@ -110,7 +110,7 @@ def __init__(self, *args, **kwargs):

        self.add_argument(
            "--scale_dtype",
            default="fp16",
            default=None,
            choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"],
            help="scale data type to use for quantization",
        )
@@ -470,6 +470,14 @@ def tune(args):
    extra_config.scheme_config = scheme_config
    extra_config.mllm_config = mllm_config

    layer_config = {}
    # from auto_round.auto_schemes.haha import get_mixed_config_layer_config
    # layer_config = {}
    # best_path = get_mixed_config_layer_config(model_name, target_bits=3)
    # for item in best_path:
    #     layer_config[item[0]] = {}
    #     layer_config[item[0]]["bits"] = item[1]

    autoround: BaseCompressor = AutoRound(
        model=model_name,
        scheme=scheme,
@@ -486,6 +494,7 @@ def tune(args):
        not_use_best_mse=args.not_use_best_mse,
        enable_adam=args.adam,
        extra_config=extra_config,
        layer_config=layer_config,
    )

    model_name = args.model.rstrip("/")
42 changes: 42 additions & 0 deletions auto_round/auto_schemes/__init__.py
@@ -0,0 +1,42 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

AUTO_SCHEMES_METHODS = {}


def register_scheme_methods(names):
    """Decorator that registers a mixed-precision scheme-search method in the registry.

    Args:
        names: A string, or a tuple/list of strings, naming the method(s) to register.

    Returns:
        register: The decorator that stores the method in AUTO_SCHEMES_METHODS.
    """

    def register(alg):
        if isinstance(names, (tuple, list)):
            for name in names:
                AUTO_SCHEMES_METHODS[name] = alg
        else:
            AUTO_SCHEMES_METHODS[names] = alg

        return alg

    return register


import auto_round.auto_schemes.haha # pylint: disable=E0611
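
The registry is exercised like the following minimal sketch. The `naive_greedy` method, its signature, and its return shape are illustrative assumptions, not part of this PR; only `register_scheme_methods` and the `AUTO_SCHEMES_METHODS` lookup come from the file above.

```python
from auto_round.auto_schemes import AUTO_SCHEMES_METHODS, register_scheme_methods


# Hypothetical search method registered under two names. GenScheme.get_layer_config
# later resolves AutoScheme.method against this same registry.
@register_scheme_methods(("naive", "naive_greedy"))
def naive_greedy(auto_scheme, model, quant_layer_names, fixed_layer_scheme, dataset, tokenizer):
    # Placeholder logic: keep pinned layers as-is and give every other layer
    # the first candidate option.
    option = auto_scheme.options[0]
    return {name: fixed_layer_scheme.get(name, {"scheme": option}) for name in quant_layer_names}


assert AUTO_SCHEMES_METHODS["naive"] is AUTO_SCHEMES_METHODS["naive_greedy"]
```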
84 changes: 84 additions & 0 deletions auto_round/auto_schemes/gen_auto_scheme.py
@@ -0,0 +1,84 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict
from typing import Iterable

import torch

from auto_round import AutoScheme
from auto_round.auto_schemes import AUTO_SCHEMES_METHODS
from auto_round.auto_schemes.utils import compute_avg_bits_for_scheme
from auto_round.logger import logger


class GenScheme:
    """Generate and validate quantization schemes for model layers."""

    def __init__(
        self,
        auto_scheme: AutoScheme,  # TODO: support shared layers
        model: torch.nn.Module,
        quant_layer_names: Iterable[str],
        fixed_layer_scheme: dict[str, dict],
        dataset: str = "pile-10k",  # TODO: use the auto-round dataset
        tokenizer=None,
    ):
        self.auto_scheme = auto_scheme
        self.model = model
        self.tokenizer = tokenizer
        self.quant_layer_names = quant_layer_names
        self.fixed_layer_scheme = fixed_layer_scheme
        self.dataset = dataset

        self._check_configs()

    def _check_configs(self) -> None:
        """Validate the auto_scheme configuration and ensure the avg_bits target is reachable."""
        if isinstance(self.model, torch.nn.Module) and self.tokenizer is None:
            raise ValueError("tokenizer must not be None if model is an nn.Module")

        if not isinstance(self.dataset, str):
            raise TypeError(f"`dataset` must be a string, got {type(self.dataset).__name__}.")

        min_avg_bit, max_avg_bit = self.compute_avg_bit_range()
        target = self.auto_scheme.avg_bits

        logger.info("Average bits range: [%.3f, %.3f], target = %.3f", min_avg_bit, max_avg_bit, target)

        if not (min_avg_bit <= target <= max_avg_bit):
            raise ValueError(
                f"Target avg_bits={target:.3f} is outside the valid range [{min_avg_bit:.3f}, {max_avg_bit:.3f}]."
            )

    def get_layer_config(self):
        method_name = self.auto_scheme.method
        method_func = AUTO_SCHEMES_METHODS[method_name]
        layer_config = method_func(
            self.auto_scheme, self.model, self.quant_layer_names, self.fixed_layer_scheme, self.dataset, self.tokenizer
        )
        return layer_config

    def compute_avg_bit_range(self) -> tuple[float, float]:
        """Compute the min and max average bitwidths among the candidate quantization options."""
        avg_bits = [
            compute_avg_bits_for_scheme(
                self.model,
                self.quant_layer_names,
                self.fixed_layer_scheme,
                option,
                self.auto_scheme.ignore_scale_zp_bits,
            )[0]
            for option in self.auto_scheme.options
        ]
        return min(avg_bits), max(avg_bits)
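
The intended call flow appears to be: build an `AutoScheme`, hand it to `GenScheme` together with the layers whose schemes are still undecided, then request the searched per-layer config. A minimal sketch follows; the `AutoScheme` keyword arguments mirror the fields this file reads (`avg_bits`, `options`, `method`, `ignore_scale_zp_bits`), but the constructor signature, the preset names, and the `"naive"` method name are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoScheme
from auto_round.auto_schemes.gen_auto_scheme import GenScheme

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Search between assumed 2-bit and 4-bit presets for a 3-bit average.
auto_scheme = AutoScheme(avg_bits=3.0, options=("W2A16", "W4A16"), method="naive", ignore_scale_zp_bits=False)

quant_layers = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
gen = GenScheme(auto_scheme, model, quant_layers, fixed_layer_scheme={}, tokenizer=tokenizer)
layer_config = gen.get_layer_config()  # _check_configs() has already validated that avg_bits is feasible
```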
184 changes: 184 additions & 0 deletions auto_round/auto_schemes/utils.py
@@ -0,0 +1,184 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, fields
from typing import Iterable, Union

import torch

from auto_round.low_cpu_mem import get_module
from auto_round.schemes import QuantizationScheme, preset_name_to_scheme
from auto_round.utils import check_to_quantized, get_layer_features


def apply_quant_scheme(
    model: torch.nn.Module,
    quant_layer_names: Iterable[str],
    fixed_layer_scheme: dict[str, dict],
    scheme: Union[str, dict],  # TODO: add scale_dtype
) -> None:
    """Apply a quantization scheme to each quantized layer.

    Args:
        model: The model whose layers are to be updated.
        quant_layer_names: Iterable of layer names to quantize.
        fixed_layer_scheme: Dictionary of fixed per-layer quantization schemes.
        scheme: The scheme preset name or dictionary to apply to the remaining layers.
    """
    for name in quant_layer_names:
        layer_scheme = fixed_layer_scheme.get(name, scheme)
        if isinstance(layer_scheme, str):
            layer_scheme = asdict(preset_name_to_scheme(layer_scheme))

        module = get_module(model, name)
        for key, value in layer_scheme.items():
            setattr(module, key, value)


def remove_quant_scheme(
    model: torch.nn.Module,
) -> None:
    """Remove the attributes set by a previously applied quantization scheme.

    Args:
        model: The model whose layers are to be cleared.
    """
    scheme_keys = [f.name for f in fields(QuantizationScheme)] + ["scale_dtype"]
    for _, module in model.named_modules():
        for key in scheme_keys:
            if hasattr(module, key):
                delattr(module, key)


def compute_avg_bits_for_scheme(
    model: torch.nn.Module,
    quant_layer_names: Iterable[str],
    fixed_layer_scheme: dict[str, dict],
    scheme: Union[str, dict, None] = None,
    ignore_scale_zp_bits: bool = False,
) -> tuple[float, float]:
    """Compute the average and total bit usage for the given quantization scheme.

    Args:
        model: The model to analyze.
        quant_layer_names: Iterable of layer names to include.
        fixed_layer_scheme: Dictionary of fixed per-layer quantization schemes.
        scheme: Optional scheme to temporarily apply before measuring.
        ignore_scale_zp_bits: If True, ignores overhead from scales and zero points.

    Returns:
        A tuple (avg_bits, total_quantized_bits):
            avg_bits: Average bitwidth per parameter.
            total_quantized_bits: Total quantized bit count.
    """
    if scheme is not None:
        apply_quant_scheme(model, quant_layer_names, fixed_layer_scheme, scheme)

    total_params = 0
    total_quantized_bits = 0

    for name in quant_layer_names:
        module = get_module(model, name)
        if not hasattr(module, "weight"):
            continue
        total_params += module.weight.numel()
        layer_bits, _ = compute_layer_bits(module, ignore_scale_zp_bits)
        total_quantized_bits += layer_bits

    avg_bits = float(total_quantized_bits) / total_params

    if scheme is not None:
        remove_quant_scheme(model)

    return avg_bits, total_quantized_bits
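
Composed, the apply/measure/remove cycle that `compute_avg_bits_for_scheme` performs internally can be sketched as below; the model, the layer names, the pinned 8-bit entry, and the preset names are illustrative assumptions.

```python
from transformers import AutoModelForCausalLM

from auto_round.auto_schemes.utils import compute_avg_bits_for_scheme

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Illustrative layer names; the pinned 8-bit entry stays fixed while every
# other listed layer takes the candidate scheme being measured.
quant_layers = ["model.decoder.layers.0.fc1", "model.decoder.layers.0.fc2"]
fixed = {"model.decoder.layers.0.fc2": {"bits": 8, "group_size": 128}}

# Each candidate is applied temporarily, measured, then removed again, so the
# model's attributes are unchanged between iterations.
for option in ("W2A16", "W4A16"):
    avg_bits, total_bits = compute_avg_bits_for_scheme(model, quant_layers, fixed, option)
    print(f"{option}: avg={avg_bits:.3f} bits, total={total_bits} bits")
```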


def compute_avg_bits_for_model(model: torch.nn.Module, ignore_scale_zp_bits: bool = False):
    """Compute the average and total bit usage for the entire model.

    Args:
        model: The model to analyze.
        ignore_scale_zp_bits: If True, ignores overhead from scales and zero points.

    Returns:
        A tuple (avg_bits, total_quantized_bits).
    """
    total_params = 0
    total_quantized_bits = 0

    for _, module in model.named_modules():
        if not hasattr(module, "bits"):
            continue
        if not hasattr(module, "weight"):
            continue
        total_params += module.weight.numel()
        layer_bits, _ = compute_layer_bits(module, ignore_scale_zp_bits)
        total_quantized_bits += layer_bits

    avg_bits = float(total_quantized_bits) / total_params

    return avg_bits, total_quantized_bits


def compute_layer_bits(
    layer: torch.nn.Module,
    ignore_scale_zp_bits: bool = False,
) -> tuple[int, float]:
    """Compute total and average bitwidth for a single quantized layer.

    Args:
        layer: A PyTorch layer with quantization attributes.
        ignore_scale_zp_bits: Whether to ignore scale/zero-point overhead.

    Returns:
        A tuple (total_bits, avg_bits) representing bit usage.
    """
    weight = layer.weight
    n_param = weight.numel()
    weight_bits = getattr(layer, "bits", 16)
    group_size = getattr(layer, "group_size", 128)
    super_group_size = getattr(layer, "super_group_size", None)
    super_weight_bits = getattr(layer, "super_bits", None)

    # Unquantized layer, or ignoring scale/zp overhead
    if weight_bits >= 16 or ignore_scale_zp_bits:
        if weight_bits >= 16 and super_weight_bits is not None:
            # Reset gguf 16-bit layers to 32 bits. TODO: gguf q4_0/q4_1 may have a bug.
            return 32 * n_param, 32.0
        return weight_bits * n_param, float(weight_bits)

    in_features, out_features = get_layer_features(layer)

    # Determine the number of groups from the group size
    if group_size > 0:
        n_group = out_features * ((in_features + group_size - 1) // group_size)
    elif group_size == 0:
        n_group = 1
    elif group_size == -1:
        n_group = out_features
    else:
        raise ValueError(f"Invalid group_size {group_size}")

    # Compute auxiliary bits (scales, zero points, or double quantization)
    aux_total_bits = 0
    if not super_group_size:
        scale_bits = 16
        zp_bits = weight_bits
        aux_total_bits = n_group * (scale_bits + zp_bits)
    else:
        aux_total_bits += n_group * super_weight_bits * 2
        n_super_group = (n_group + super_group_size - 1) // super_group_size
        aux_total_bits += n_super_group * 32 * 2  # 32-bit scale and min_v

    total_bits = weight_bits * n_param + aux_total_bits
    avg_bits = total_bits / n_param
    return total_bits, avg_bits
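
As a sanity check of the accounting in `compute_layer_bits`, consider a hypothetical 4096x4096 linear layer with `bits=4`, `group_size=128`, and no super group; the arithmetic below follows the branches above.

```python
# 4096 x 4096 weight, bits=4, group_size=128, super_group_size=None
n_param = 4096 * 4096                       # 16_777_216 weights
n_group = 4096 * ((4096 + 128 - 1) // 128)  # 4096 rows x 32 groups = 131_072
aux = n_group * (16 + 4)                    # fp16 scale + 4-bit zero point per group = 2_621_440
total = 4 * n_param + aux                   # 67_108_864 + 2_621_440 = 69_730_304
avg = total / n_param                       # 4.15625 bits per weight
```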
5 changes: 2 additions & 3 deletions auto_round/autoround.py
@@ -25,7 +25,7 @@
    MLLMCompressor,
)
from auto_round.logger import deprecated, logger
from auto_round.schemes import QuantizationScheme
from auto_round.schemes import AutoScheme, QuantizationScheme
from auto_round.utils import is_mllm_model


@@ -63,7 +63,7 @@ def __new__(
        cls,
        model: Union[torch.nn.Module, str],
        tokenizer=None,
        scheme: Union[str, dict, QuantizationScheme] = "W4A16",
        scheme: Union[str, dict, QuantizationScheme, AutoScheme] = "W4A16",
        layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
        dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
        iters: int = 200,
@@ -77,7 +77,6 @@ def __new__(
        seed: int = 42,
        # for adam
        enable_adam: bool = False,
        # for MLLM
        extra_config: ExtraConfig = None,
        **kwargs,
    ) -> BaseCompressor:
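
The net effect of this signature change is that an `AutoScheme` can be passed wherever a preset string such as `"W4A16"` went before. A minimal sketch with the same hedges as earlier (the `AutoScheme` field names follow gen_auto_scheme.py; the preset names and the `"naive"` method name are assumptions):

```python
from auto_round import AutoRound, AutoScheme

# Mixed-precision search in place of a single preset.
scheme = AutoScheme(avg_bits=3.0, options=("W2A16", "W4A16"), method="naive", ignore_scale_zp_bits=False)
ar = AutoRound(model="facebook/opt-125m", scheme=scheme, iters=200)
```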