Support for aliased scheme settings in quant config (#40)
bfineran authored May 22, 2024
1 parent 2c64578 commit 7572c7b
Showing 4 changed files with 110 additions and 5 deletions.
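In short, a config group can now be keyed by a preset scheme name ("W8A8" or "W4A16") mapped straight to a list of target layers, instead of spelling out a full QuantizationScheme. A minimal sketch of the new usage, mirroring the test added in this commit:

    from compressed_tensors.quantization import QuantizationConfig, QuantizationScheme

    # "W4A16" is an alias for a preset scheme; the value is just the list of
    # layer types to quantize
    config = QuantizationConfig(config_groups={"W4A16": ["Linear"]})

    # after init the alias has been expanded into a full scheme
    assert isinstance(config.config_groups["W4A16"], QuantizationScheme)
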
25 changes: 21 additions & 4 deletions src/compressed_tensors/quantization/quant_config.py
@@ -13,11 +13,14 @@
 # limitations under the License.

 from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 from compressed_tensors.base import QUANTIZATION_CONFIG_NAME
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.quant_scheme import (
+    QuantizationScheme,
+    preset_name_to_scheme,
+)
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
     is_module_quantized,
@@ -105,7 +108,8 @@ class QuantizationConfig(BaseModel):
         mapped to a QuantizationScheme in config_groups.
     :param config_groups: dict of QuantizationSchemes specifying the quantization
-        settings for each quantized layer
+        settings for each quantized layer. A group could also be a reference to
+        a predefined scheme name, mapped to a list of its target layers/classes
     :param quant_method: a constant used to differentiate sparseML quantization from
         other quantization configs
     :param format: specifies how the quantized model is stored on disk
@@ -117,13 +121,26 @@ class QuantizationConfig(BaseModel):
         are not quantized even if they match up with a target in config_groups
     """

-    config_groups: Dict[str, QuantizationScheme]
+    config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
     quant_method: str = "sparseml"
     format: str = "fakequant"
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)

+    def model_post_init(self, __context):
+        """
+        updates any quantization schemes defined as presets to be fully loaded
+        schemes
+        """
+        for group_name, targets_or_scheme in self.config_groups.items():
+            if isinstance(targets_or_scheme, QuantizationScheme):
+                continue  # scheme already defined
+            self.config_groups[group_name] = preset_name_to_scheme(
+                name=group_name,
+                targets=targets_or_scheme,
+            )
+
     @staticmethod
     def from_model_config(model_name_or_path) -> "QuantizationConfig":
         """
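The model_post_init hook above runs after pydantic validation and only rewrites groups whose value is a list of target names; a fully specified QuantizationScheme passes through untouched, so existing configs keep working. A hedged sketch of a mixed config (the "group_0" name and its targets are illustrative, not from the diff):

    explicit = QuantizationScheme(targets=["Linear"])
    config = QuantizationConfig(
        config_groups={
            "group_0": explicit,  # full scheme: passed through as-is
            "W8A8": ["Linear"],   # preset alias: expanded by model_post_init
        }
    )
    assert isinstance(config.config_groups["W8A8"], QuantizationScheme)
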
45 changes: 44 additions & 1 deletion src/compressed_tensors/quantization/quant_scheme.py
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from copy import deepcopy
 from typing import List, Optional

 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from pydantic import BaseModel


-__all__ = ["QuantizationScheme"]
+__all__ = [
+    "QuantizationScheme",
+    "preset_name_to_scheme",
+]


 class QuantizationScheme(BaseModel):
@@ -65,3 +69,42 @@ def default_scheme(
             input_activations=input_activations,
             output_activations=output_activations,
         )
+
+
+"""
+Pre-Set Quantization Scheme Args
+"""
+
+
+def preset_name_to_scheme(name: str, targets: List[str]) -> QuantizationScheme:
+    """
+    :param name: preset quantization settings name. must exist in upper case in
+        PRESET_SCHEMES
+    :param targets: list of quantization targets to be passed to the Scheme
+    :return: new QuantizationScheme for a given name with the given targets
+    """
+    name = name.upper()
+
+    if name not in PRESET_SCHEMES:
+        raise KeyError(
+            f"Unknown preset scheme name {name}, "
+            f"available names: {list(PRESET_SCHEMES.keys())}"
+        )
+
+    scheme_args = deepcopy(PRESET_SCHEMES[name])  # deepcopy to avoid args references
+    return QuantizationScheme(
+        targets=targets,
+        **scheme_args,
+    )
+
+
+W8A8 = dict(
+    weights=QuantizationArgs(), input_activations=QuantizationArgs(symmetric=False)
+)
+
+W4A16 = dict(weights=QuantizationArgs(num_bits=4, symmetric=False))
+
+PRESET_SCHEMES = {
+    "W8A8": W8A8,
+    "W4A16": W4A16,
+}
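
Because lookups go through name.upper(), preset names are effectively case-insensitive, and the deepcopy means each returned scheme owns its own QuantizationArgs rather than sharing the preset's. A small usage sketch, assuming the definitions above:

    # case-insensitive lookup against PRESET_SCHEMES
    scheme = preset_name_to_scheme("w4a16", targets=["Linear"])
    assert scheme.weights.num_bits == 4

    # unknown names fail fast, listing the available presets
    try:
        preset_name_to_scheme("W2A2", targets=["Linear"])
    except KeyError:
        pass
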
13 changes: 13 additions & 0 deletions tests/quantization/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
32 changes: 32 additions & 0 deletions tests/quantization/test_quant_config.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from compressed_tensors.quantization import QuantizationConfig, QuantizationScheme
+
+
+@pytest.mark.parametrize(
+    "scheme_name",
+    [
+        "W8A8",
+        "W4A16",
+    ],
+)
+def test_load_scheme_from_preset(scheme_name: str):
+    targets = ["Linear"]
+    config = QuantizationConfig(config_groups={scheme_name: targets})
+
+    assert scheme_name in config.config_groups
+    assert isinstance(config.config_groups[scheme_name], QuantizationScheme)
+    assert config.config_groups[scheme_name].targets == targets
