-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
Merge pull request #3 from neuralmagic/sa/quant_config
Define BaseModels for Quantization
- Loading branch information
Showing
8 changed files
with
355 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# flake8: noqa | ||
from .quant_args import * | ||
from .quant_config import * | ||
from .quant_scheme import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from enum import Enum | ||
from typing import Optional | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"] | ||
|
||
|
||
class QuantizationType(Enum): | ||
""" | ||
Enum storing quantization type options | ||
""" | ||
|
||
INT = "int" | ||
FLOAT = "float" | ||
|
||
|
||
class QuantizationStrategy(Enum): | ||
""" | ||
Enum storing quantization strategy options | ||
""" | ||
|
||
TENSOR = "tensor" | ||
CHANNEL = "channel" | ||
GROUP = "group" | ||
BLOCK = "block" | ||
|
||
|
||
class QuantizationArgs(BaseModel): | ||
""" | ||
User facing arguments used to define a quantization config for weights or | ||
activations | ||
:param num_bits: quantization bit depth | ||
:param type: dtype to quantized to, either int or float | ||
:param symmetric: whether or not quantization scale is symmetric about zero-point | ||
:param strategy: string id determining the scope of scale/zero-point to apply | ||
:param group_size: group length to use for the group strategy | ||
:param block_structure: 2d block structure to use for the block strategy, must be | ||
of the format "2x4", "8x16", etc. | ||
""" | ||
|
||
num_bits: int = 8 | ||
type: QuantizationType = QuantizationType.INT | ||
symmetric: bool = True | ||
strategy: QuantizationStrategy = QuantizationStrategy.TENSOR | ||
group_size: Optional[int] = None | ||
block_structure: Optional[str] = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from enum import Enum | ||
from typing import Dict, List, Optional | ||
|
||
from pydantic import BaseModel | ||
from sparsetensors.quantization.quant_scheme import QuantizationScheme | ||
|
||
|
||
__all__ = ["QuantizationStatus", "QuantizationConfig"] | ||
|
||
|
||
class QuantizationStatus(Enum): | ||
""" | ||
Enum storing the different states a quantized layer can be in | ||
Initialized: scale, zero points and observers have been attached to the layer but | ||
are set to dummy values (not yet calibrated) | ||
Calibration: scale and zero points have been calibrated through OBCQ or similar | ||
algorithm, observers are still attached | ||
Frozen: scale and zero points are finalized, observers have been deleted, weights | ||
are still in their original precision | ||
Compressed: weights have been converted to their target type or compressed to | ||
their closed approximation | ||
""" | ||
|
||
INITIALIZED = "initialized" | ||
CALIBRATION = "calibration" | ||
FROZEN = "frozen" | ||
COMPRESSED = "compressed" | ||
|
||
|
||
class QuantizationConfig(BaseModel): | ||
""" | ||
Full configuration specifying how a model is quantized. Each quantized layer is | ||
mapped to a QuantizationScheme in config_groups. | ||
:param config_groups: dict of QuantizationSchemes specifying the quantization | ||
settings for each quantized layer | ||
:param quant_method: a constant used to differentiate sparseML quantization from | ||
other quantization configs | ||
:param format: specifies how the quantized model is stored on disk | ||
:quantization_status: specifies the current status of all quantized layers. It is | ||
assumed all layers are in the same state. | ||
:global_compression_ratio: optional informational config to report the model | ||
compression ratio acheived by the quantization config | ||
:ignore: optional list of layers to ignore from config_groups. Layers in this list | ||
are not quantized even if they match up with a target in config_groups | ||
""" | ||
|
||
config_groups: Dict[str, QuantizationScheme] | ||
quant_method: str = "sparseml" | ||
format: str = "fakequant" | ||
quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED | ||
global_compression_ratio: Optional[float] = None | ||
ignore: Optional[List[str]] = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import List, Optional | ||
|
||
from pydantic import BaseModel | ||
from sparsetensors.quantization.quant_args import QuantizationArgs | ||
|
||
|
||
__all__ = ["QuantizationScheme"] | ||
|
||
|
||
class QuantizationScheme(BaseModel): | ||
""" | ||
Set of QuantizationArgs defining how the weights, inputs and outputs of target list | ||
of modules should be quantized | ||
:param targets: list of modules to apply the QuantizationArgs to, can be layer | ||
names, layer types or a regular expression | ||
:param weights: quantization config for layer weights | ||
:param input_activations: quantization config for layer inputs | ||
:param output_activations: quantization config for layer outputs | ||
""" | ||
|
||
targets: List[str] | ||
weights: Optional[QuantizationArgs] = None | ||
input_activations: Optional[QuantizationArgs] = None | ||
output_activations: Optional[QuantizationArgs] = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import pytest | ||
from pydantic import ValidationError | ||
from sparsetensors.quantization import ( | ||
QuantizationArgs, | ||
QuantizationStrategy, | ||
QuantizationType, | ||
) | ||
|
||
|
||
def test_defaults(): | ||
default = QuantizationArgs() | ||
|
||
assert default.num_bits == 8 | ||
assert default.type == QuantizationType.INT | ||
assert default.symmetric | ||
assert default.strategy == QuantizationStrategy.TENSOR | ||
assert default.group_size is None | ||
assert default.block_structure is None | ||
|
||
|
||
def test_group(): | ||
kwargs = {"strategy": "group", "group_size": 128} | ||
|
||
group = QuantizationArgs(**kwargs) | ||
assert group.strategy == QuantizationStrategy.GROUP | ||
assert group.group_size == kwargs["group_size"] | ||
|
||
|
||
def test_block(): | ||
kwargs = {"strategy": "block", "block_structure": "2x4"} | ||
|
||
block = QuantizationArgs(**kwargs) | ||
assert block.strategy == QuantizationStrategy.BLOCK | ||
assert block.block_structure == kwargs["block_structure"] | ||
|
||
|
||
def test_invalid(): | ||
with pytest.raises(ValidationError): | ||
_ = QuantizationArgs(type="invalid") | ||
with pytest.raises(ValidationError): | ||
_ = QuantizationArgs(strategy="invalid") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import pytest | ||
from pydantic import ValidationError | ||
from sparsetensors.quantization import ( | ||
QuantizationConfig, | ||
QuantizationScheme, | ||
QuantizationStatus, | ||
) | ||
|
||
|
||
def test_basic_config(): | ||
config_groups = {"group_1": QuantizationScheme(targets=[])} | ||
config = QuantizationConfig(config_groups=config_groups) | ||
|
||
assert config.config_groups == config_groups | ||
assert config.quant_method == "sparseml" | ||
assert config.format == "fakequant" | ||
assert config.quantization_status == QuantizationStatus.INITIALIZED | ||
assert config.global_compression_ratio is None | ||
assert config.ignore is None | ||
|
||
|
||
def test_full_config(): | ||
config_groups = { | ||
"group_1": QuantizationScheme(targets=[]), | ||
"group_2": QuantizationScheme(targets=[]), | ||
} | ||
global_compression_ratio = 3.5 | ||
ignore = ["model.layers.0"] | ||
quantization_status = "compressed" | ||
|
||
config = QuantizationConfig( | ||
config_groups=config_groups, | ||
global_compression_ratio=global_compression_ratio, | ||
ignore=ignore, | ||
quantization_status=quantization_status, | ||
) | ||
assert config.config_groups == config_groups | ||
assert config.global_compression_ratio == global_compression_ratio | ||
assert config.ignore == ignore | ||
assert config.quantization_status == QuantizationStatus.COMPRESSED | ||
|
||
|
||
def test_need_config_groups(): | ||
with pytest.raises(ValidationError): | ||
_ = QuantizationScheme() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import pytest | ||
from pydantic import ValidationError | ||
from sparsetensors.quantization import QuantizationArgs, QuantizationScheme | ||
|
||
|
||
def test_basic_scheme(): | ||
targets = ["model.layer.0", "model.layer.3"] | ||
weights = QuantizationArgs() | ||
|
||
scheme = QuantizationScheme(targets=targets, weights=weights) | ||
assert scheme.targets == targets | ||
assert scheme.weights == weights | ||
assert scheme.input_activations is None | ||
assert scheme.output_activations is None | ||
|
||
|
||
def test_full_scheme(): | ||
targets = ["Linear"] | ||
weights = QuantizationArgs() | ||
input_activations = QuantizationArgs(num_bits=4) | ||
output_activations = QuantizationArgs(num_bits=8, type="float", symmetric=False) | ||
|
||
scheme = QuantizationScheme( | ||
targets=targets, | ||
weights=weights, | ||
input_activations=input_activations, | ||
output_activations=output_activations, | ||
) | ||
assert scheme.targets == targets | ||
assert scheme.weights == weights | ||
assert scheme.input_activations == input_activations | ||
assert scheme.output_activations == output_activations | ||
|
||
|
||
def test_needs_targets(): | ||
with pytest.raises(ValidationError): | ||
_ = QuantizationScheme() |