Support for aliased scheme settings in quant config #40

Merged · 3 commits · May 22, 2024
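As a quick illustration of the feature this PR adds (a sketch, not part of the PR itself; the `lm_head` entry is hypothetical): a `config_groups` entry can now name a preset scheme and map it directly to its target layer types, instead of spelling out a full `QuantizationScheme`.

```python
from compressed_tensors.quantization import QuantizationConfig

# Aliased form added by this PR: the group key "W8A8" is a preset name and its
# value is the list of targets, rather than a fully specified QuantizationScheme.
config = QuantizationConfig(
    config_groups={"W8A8": ["Linear"]},
    ignore=["lm_head"],  # hypothetical layer to exclude from quantization
)
```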
25 changes: 21 additions & 4 deletions src/compressed_tensors/quantization/quant_config.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 from compressed_tensors.base import QUANTIZATION_CONFIG_NAME
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.quant_scheme import (
+    QuantizationScheme,
+    preset_name_to_scheme,
+)
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
     is_module_quantized,
@@ -105,7 +108,8 @@ class QuantizationConfig(BaseModel):
     mapped to a QuantizationScheme in config_groups.
 
     :param config_groups: dict of QuantizationSchemes specifying the quantization
-        settings for each quantized layer
+        settings for each quantized layer. A group could also be a reference to
+        a predefined scheme name, mapped to a list of its target layers/classes
     :param quant_method: a constant used to differentiate sparseML quantization from
         other quantization configs
     :param format: specifies how the quantized model is stored on disk
@@ -117,13 +121,26 @@ class QuantizationConfig(BaseModel):
        are not quantized even if they match up with a target in config_groups
    """
 
-    config_groups: Dict[str, QuantizationScheme]
+    config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
     quant_method: str = "sparseml"
     format: str = "fakequant"
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
 
+    def model_post_init(self, __context):
+        """
+        updates any quantization schemes defined as presets to be fully loaded
+        schemes
+        """
+        for group_name, targets_or_scheme in self.config_groups.items():
+            if isinstance(targets_or_scheme, QuantizationScheme):
+                continue  # scheme already defined
+            self.config_groups[group_name] = preset_name_to_scheme(
+                name=group_name,
+                targets=targets_or_scheme,
+            )
+
     @staticmethod
     def from_model_config(model_name_or_path) -> "QuantizationConfig":
         """
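A minimal sketch of the behavior `model_post_init` adds, assuming the public imports exercised by the new test further down: pydantic runs the hook during construction, so by the time the config is returned every preset alias has been replaced by a fully loaded scheme.

```python
from compressed_tensors.quantization import QuantizationConfig, QuantizationScheme

config = QuantizationConfig(config_groups={"W4A16": ["Linear"]})

# model_post_init already ran inside the constructor, so the alias entry
# has been expanded into the preset's full scheme:
group = config.config_groups["W4A16"]
assert isinstance(group, QuantizationScheme)
assert group.targets == ["Linear"]
assert group.weights.num_bits == 4  # per the W4A16 preset defined below
```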
45 changes: 44 additions & 1 deletion src/compressed_tensors/quantization/quant_scheme.py
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
 from typing import List, Optional
 
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from pydantic import BaseModel
 
 
-__all__ = ["QuantizationScheme"]
+__all__ = [
+    "QuantizationScheme",
+    "preset_name_to_scheme",
+]
 
 
 class QuantizationScheme(BaseModel):
@@ -65,3 +69,42 @@ def default_scheme(
         input_activations=input_activations,
         output_activations=output_activations,
     )
+
+
+"""
+Pre-Set Quantization Scheme Args
+"""
+
+
+def preset_name_to_scheme(name: str, targets: List[str]) -> QuantizationScheme:
+    """
+    :param name: preset quantization settings name. must exist in upper case in
+        PRESET_SCHEMES
+    :param targets: list of quantization targets to be passed to the Scheme
+    :return: new QuantizationScheme for a given name with the given targets
+    """
+    name = name.upper()
+
+    if name not in PRESET_SCHEMES:
+        raise KeyError(
+            f"Unknown preset scheme name {name}, "
+            f"available names: {list(PRESET_SCHEMES.keys())}"
+        )
+
+    scheme_args = deepcopy(PRESET_SCHEMES[name])  # deepcopy to avoid args references
+    return QuantizationScheme(
+        targets=targets,
+        **scheme_args,
+    )
+
+
+W8A8 = dict(
+    weights=QuantizationArgs(), input_activations=QuantizationArgs(symmetric=False)
+)
+
+W4A16 = dict(weights=QuantizationArgs(num_bits=4, symmetric=False))
+
+PRESET_SCHEMES = {
+    "W8A8": W8A8,
+    "W4A16": W4A16,
+}
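For completeness, a hedged sketch of calling `preset_name_to_scheme` directly: the lookup is effectively case-insensitive because the name is upper-cased first, and unknown names raise a `KeyError` listing the available presets ("W2A2" below is deliberately not a defined preset).

```python
from compressed_tensors.quantization.quant_scheme import preset_name_to_scheme

scheme = preset_name_to_scheme("w8a8", targets=["Linear"])  # upper-cased to "W8A8"
print(scheme.targets)                      # ['Linear']
print(scheme.input_activations.symmetric)  # False, set by the W8A8 preset

try:
    preset_name_to_scheme("W2A2", targets=["Linear"])  # not a defined preset
except KeyError as err:
    print(err)  # names the unknown preset and lists the available ones
```

The `deepcopy` matters because the preset dicts hold live `QuantizationArgs` instances at module level; without it, every scheme built from the same preset would share, and could mutate, the same args objects.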
13 changes: 13 additions & 0 deletions tests/quantization/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
32 changes: 32 additions & 0 deletions tests/quantization/test_quant_config.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from compressed_tensors.quantization import QuantizationConfig, QuantizationScheme
+
+
+@pytest.mark.parametrize(
+    "scheme_name",
+    [
+        "W8A8",
+        "W4A16",
+    ],
+)
+def test_load_scheme_from_preset(scheme_name: str):
+    targets = ["Linear"]
+    config = QuantizationConfig(config_groups={scheme_name: targets})
+
+    assert scheme_name in config.config_groups
+    assert isinstance(config.config_groups[scheme_name], QuantizationScheme)
+    assert config.config_groups[scheme_name].targets == targets
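Beyond what the test covers, the `Union` annotation also permits mixing alias entries with explicit schemes in one config. A sketch (the `group_0` name and `Embedding` target are illustrative):

```python
from compressed_tensors.quantization import QuantizationConfig, QuantizationScheme
from compressed_tensors.quantization.quant_args import QuantizationArgs

config = QuantizationConfig(
    config_groups={
        "W8A8": ["Linear"],             # alias: expanded by model_post_init
        "group_0": QuantizationScheme(  # explicit scheme: passed through as-is
            targets=["Embedding"],
            weights=QuantizationArgs(num_bits=8),
        ),
    }
)
```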