From 2a8748b064a2bbadbc273dec0fb80f9874338d84 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 17 Oct 2023 17:20:29 -0400 Subject: [PATCH 1/5] basic implementation working --- src/sparseml/modifiers/obcq/base.py | 4 +- src/sparseml/modifiers/obcq/pytorch.py | 43 ++++++++++++++++++- .../sparsification/obcq/example.yaml | 22 +++++----- 3 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py index fe66f61e505..b8f69f0958f 100644 --- a/src/sparseml/modifiers/obcq/base.py +++ b/src/sparseml/modifiers/obcq/base.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union from sparseml.core import Modifier from sparseml.core.state import State @@ -50,7 +50,7 @@ class SparseGPTModifier(Modifier): sparsity: float block_size: int - quantize: bool + quantize: Union[bool, Dict] dampening_frac: Optional[float] = 0.01 sequential_update: Optional[bool] = True prunen: Optional[int] = 0 diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py index 3a043df9707..1d5429891df 100644 --- a/src/sparseml/modifiers/obcq/pytorch.py +++ b/src/sparseml/modifiers/obcq/pytorch.py @@ -19,6 +19,8 @@ import torch from torch.nn import Module +from sparseml.core.factory import ModifierFactory +from sparseml.core.framework import Framework from sparseml.core.model import ModifiableModel from sparseml.core.state import State from sparseml.modifiers.obcq.base import SparseGPTModifier @@ -49,6 +51,7 @@ class SparseGPTModifierPyTorch(SparseGPTModifier): compressible_layers_: List = None device_: str = "cuda:0" finalization_kwargs_: Dict = None + quantization_modifier_: Any = None def compressible_layers(self) -> List[Module]: """ @@ -59,12 +62,48 @@ def compressible_layers(self) -> List[Module]: compressible_dict = self.model.get_layers(self.compress_layers) return [v for _, v in compressible_dict.items()] + def pre_initialize_structure(self, state: State, **kwargs): + if isinstance(self.quantize, bool): + if not self.quantize: + return + + else: + if not isinstance(self.quantize, Dict): + raise ValueError( + "SparseGPTModifier.quantize accepts only a single " + "quantization modifier or a boolean. Found " + f"type {type(self.quantize)}" + ) + if len(self.quantize) != 1: + raise ValueError( + "SparseGPTModifier.quantize accepts only a single " + "quantization modifier or a boolean. 
Found " + f"{len(self.quantize)} modifiers" + ) + modifier_type = list(self.quantize.keys())[0] + modifier_args = self.quantize[modifier_type] + self.quantization_modifier_ = ModifierFactory.create( + modifier_type, + framework=Framework.pytorch, + allow_registered=True, + allow_experimental=True, + **modifier_args, + ) + self.quantize = True + + if self.quantization_modifier_: + self.quantization_modifier_.pre_initialize_structure(state, **kwargs) + def on_initialize(self, state: "State", **kwargs) -> bool: """ Initialize and run the OBCQ algorithm on the current state :param state: session state storing input model and calibration data """ + if not self.initialized_structure_: + self.pre_initialize_structure(state, **kwargs) + if self.quantization_modifier_: + self.quantization_modifier_.initialize(state, **kwargs) self.finalization_kwargs_ = {} modifiable_model = state.model calibration_dataloader = state.data.calib @@ -151,9 +190,11 @@ def on_finalize(self, state: "State", **kwargs) -> bool: :param state: un-used, for matching spec of Modifier base class """ use_cache = self.finalization_kwargs_.get("use_cache", False) - self.model.apply(torch.quantization.disable_observer) self.model.config.use_cache = use_cache + if self.quantization_modifier_: + self.quantization_modifier_.finalize(state, **kwargs) + return True def compress_bottom( diff --git a/src/sparseml/transformers/sparsification/obcq/example.yaml b/src/sparseml/transformers/sparsification/obcq/example.yaml index 5987e220902..aa051b667d7 100644 --- a/src/sparseml/transformers/sparsification/obcq/example.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example.yaml @@ -1,20 +1,20 @@ test_stage: obcq_modifiers: - QuantizationModifier: - ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"] - post_oneshot_calibration: True - scheme_overrides: - ReLU: - input_activations: null - output_activations: null - LayerNorm: - input_activations: null - output_activations: null SparseGPTModifier: sparsity: 0.5 block_size: 128 sequential_update: False - quantize: True + quantize: + QuantizationModifier: + ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"] + post_oneshot_calibration: True + scheme_overrides: + ReLU: + input_activations: null + output_activations: null + LayerNorm: + input_activations: null + output_activations: null percdamp: 0.01 prunen: 0 prunem: 0 From f31a5cb6575bd4c4579bb988441d4f0c692670c8 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 18 Oct 2023 17:21:25 -0400 Subject: [PATCH 2/5] qat active function and edge cases --- src/sparseml/core/model/base.py | 8 ++ src/sparseml/core/model/pytorch.py | 9 +++ src/sparseml/modifiers/obcq/base.py | 80 ++++++++++++++++++- src/sparseml/modifiers/obcq/pytorch.py | 46 ----------- .../sparsification/obcq/example.yaml | 6 +- src/sparseml/utils/pytorch/module.py | 16 ++++ 6 files changed, 114 insertions(+), 51 deletions(-) diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py index 602682004a2..387114ad96a 100644 --- a/src/sparseml/core/model/base.py +++ b/src/sparseml/core/model/base.py @@ -116,3 +116,11 @@ def set_param(self, target: str, param: PT): :param param: the param instance to set """ raise NotImplementedError() + + def qat_active(self) -> 
bool: + """ + Checks if quantization aware training is set up in the model + + :return: True if QAT is active in any layer, False otherwise + """ + raise NotImplementedError() diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py index 670d164900c..258675115ba 100644 --- a/src/sparseml/core/model/pytorch.py +++ b/src/sparseml/core/model/pytorch.py @@ -24,6 +24,7 @@ get_layers_params, get_param, get_params, + qat_active, set_layer, set_param, ) @@ -94,3 +95,11 @@ def set_param(self, target: str, param: Parameter): :param param: the parameter to set """ return set_param(target, param, self.model) + + def qat_active(self) -> bool: + """ + Checks if quantization aware training is set up in the model + + :return: True if QAT is active in any layer, False otherwise + """ + return qat_active(self.model) diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py index b8f69f0958f..10b720b5846 100644 --- a/src/sparseml/modifiers/obcq/base.py +++ b/src/sparseml/modifiers/obcq/base.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Dict, List, Optional, Union +import logging +from typing import Any, Dict, List, Optional, Union from sparseml.core import Modifier +from sparseml.core.factory import ModifierFactory from sparseml.core.state import State from sparseml.utils import ALL_TOKEN __all__ = ["SparseGPTModifier"] +_LOGGER = logging.getLogger(__name__) + class SparseGPTModifier(Modifier): """ @@ -59,5 +62,74 @@ class SparseGPTModifier(Modifier): target_ids: Optional[List[str]] = None layer_prefix: Optional[str] = None - def on_initialize_structure(self, state: "State", **kwargs): - pass # nothing needed for this modifier + compressible_layers_: List = None + quantization_modifier_: Any = None + + def compressible_layers(self) -> List: + """ + Retrieves the modules corresponding to a list of compressible layer names + + :return: list of Pytorch modules to compress + """ + compressible_dict = self.model.get_layers(self.compress_layers) + return [v for _, v in compressible_dict.items()] + + def pre_initialize_structure(self, state: State, **kwargs): + quantization_already_active = state.model.qat_active() + if isinstance(self.quantize, bool): + if not self.quantize and quantization_already_active: + _LOGGER.warning( + "SparseGPT quantization is set to False, but a " + "quantization modifier is already active on the model " + "resetting quantize to True" + ) + self.quantize = True + elif self.quantize and not quantization_already_active: + _LOGGER.warning( + "SparseGPT quantization is set to True without an " + "active quantization modifier. Creating a default " + "8-bit quantization modifier" + ) + default_quant_config = {"QuantizationModifier": {}} + self._build_quant_modifier_from_dict( + default_quant_config, state.framework + ) + return # use existing quantization modifier if there is one + else: + if not isinstance(self.quantize, Dict): + raise ValueError( + "SparseGPTModifier.quantize accepts only a single " + "quantization modifier or a boolean. Found " + f"type {type(self.quantize)}" + ) + if len(self.quantize) != 1: + raise ValueError( + "SparseGPTModifier.quantize accepts only a single " + "quantization modifier or a boolean. Found " + f"{len(self.quantize)} modifiers" + ) + if quantization_already_active: + _LOGGER.warning( + "Attempting to initialize quantization for SparseGPT " + "but a quantization modifier has already been applied. 
" + "The quantization configuration defined under the " + "SparseGPT modifier will be ignored." + ) + self.quantize = True + return + self._build_quant_modifier_from_dict(self.quantize, state.framework) + self.quantize = True + + if self.quantization_modifier_: + self.quantization_modifier_.pre_initialize_structure(state, **kwargs) + + def _build_quant_modifier_from_dict(self, quant_config, framework): + modifier_type = list(quant_config.keys())[0] + modifier_args = quant_config[modifier_type] + self.quantization_modifier_ = ModifierFactory.create( + modifier_type, + framework=framework, + allow_registered=True, + allow_experimental=True, + **modifier_args, + ) diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py index 1d5429891df..afdf530985f 100644 --- a/src/sparseml/modifiers/obcq/pytorch.py +++ b/src/sparseml/modifiers/obcq/pytorch.py @@ -17,10 +17,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple import torch -from torch.nn import Module -from sparseml.core.factory import ModifierFactory -from sparseml.core.framework import Framework from sparseml.core.model import ModifiableModel from sparseml.core.state import State from sparseml.modifiers.obcq.base import SparseGPTModifier @@ -48,51 +45,8 @@ class SparseGPTModifierPyTorch(SparseGPTModifier): """ model: Any = None - compressible_layers_: List = None device_: str = "cuda:0" finalization_kwargs_: Dict = None - quantization_modifier_: Any = None - - def compressible_layers(self) -> List[Module]: - """ - Retrieves the modules corresponding to a list of compressible layer names - - :return: list of Pytorch modules to compress - """ - compressible_dict = self.model.get_layers(self.compress_layers) - return [v for _, v in compressible_dict.items()] - - def pre_initialize_structure(self, state: State, **kwargs): - if isinstance(self.quantize, bool): - if not self.quantize: - return - - else: - if not isinstance(self.quantize, Dict): - raise ValueError( - "SparseGPTModifier.quantize accepts only a single " - "quantization modifier or a boolean. Found " - f"type {type(self.quantize)}" - ) - if len(self.quantize) != 1: - raise ValueError( - "SparseGPTModifier.quantize accepts only a single " - "quantization modifier or a boolean. 
Found " - f"{len(self.quantize)} modifiers" - ) - modifier_type = list(self.quantize.keys())[0] - modifier_args = self.quantize[modifier_type] - self.quantization_modifier_ = ModifierFactory.create( - modifier_type, - framework=Framework.pytorch, - allow_registered=True, - allow_experimental=True, - **modifier_args, - ) - self.quantize = True - - if self.quantization_modifier_: - self.quantization_modifier_.pre_initialize_structure(state, **kwargs) def on_initialize(self, state: "State", **kwargs) -> bool: """ diff --git a/src/sparseml/transformers/sparsification/obcq/example.yaml b/src/sparseml/transformers/sparsification/obcq/example.yaml index aa051b667d7..d971eed1385 100644 --- a/src/sparseml/transformers/sparsification/obcq/example.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example.yaml @@ -45,4 +45,8 @@ test_stage: "model.decoder.layers.23" ] target_ids: ["attention_mask"] - layer_prefix: "decoder" \ No newline at end of file + layer_prefix: "decoder" + +# Llama model.model.layer.0 +# OPT model model.model.decoder.layer.0 +# other model model.model.encoder.extra.for.some.reason.layer.0 \ No newline at end of file diff --git a/src/sparseml/utils/pytorch/module.py b/src/sparseml/utils/pytorch/module.py index 05fae4a174a..fb3956016b0 100644 --- a/src/sparseml/utils/pytorch/module.py +++ b/src/sparseml/utils/pytorch/module.py @@ -65,6 +65,7 @@ "get_terminal_layers", "get_prunable_layers", "get_quantizable_layers", + "qat_active", "get_layers_params", ] @@ -241,6 +242,21 @@ def get_quantizable_layers(module: Module) -> Dict[str, Module]: return quantizable +def qat_active(module: Module) -> bool: + """ + Determines if any layers in the model have quantization enabled by checking for + weight_fake_quant attributes + + :param module: PyTorch model to check for quantization + :return: True if quantization is active anywhere in the model, False otherwise + """ + for _, layer in module.named_modules(): + if isinstance(layer, torch.quantization.FakeQuantize): + return True + + return False + + def get_layers_params( targets: Union[str, List[str]], module: Module ) -> Dict[str, ModelParameterizedLayer[Parameter, Module]]: From a72c5b135e19769c452086b90fc04052f21873e2 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 18 Oct 2023 21:19:40 -0400 Subject: [PATCH 3/5] tests for obcq quant --- .../pytorch/modifiers/obcq/__init__.py | 13 +++ .../pytorch/modifiers/obcq/test_pytorch.py | 107 ++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 tests/sparseml/pytorch/modifiers/obcq/__init__.py create mode 100644 tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py diff --git a/tests/sparseml/pytorch/modifiers/obcq/__init__.py b/tests/sparseml/pytorch/modifiers/obcq/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/obcq/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py b/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py new file mode 100644 index 00000000000..709a932e3cb --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py @@ -0,0 +1,107 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch +from sparseml.modifiers.quantization import QuantizationModifier +from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch +from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory +from tests.sparseml.pytorch.helpers import LinearNet + + +def test_create_default_quant_modifier(): + setup_modifier_factory() + kwargs = dict(sparsity=0.5, block_size=128, quantize=True) + + modifier = SparseGPTModifierPyTorch(**kwargs) + assert modifier.quantization_modifier_ is None + + testing_harness = LifecyleTestingHarness(model=LinearNet()) + modifier.pre_initialize_structure(testing_harness.get_state()) + assert modifier.quantize + assert isinstance(modifier.quantization_modifier_, QuantizationModifier) + + should_be_default_quant_scheme = modifier.quantization_modifier_.scheme + assert should_be_default_quant_scheme.input_activations.num_bits == 8 + assert not should_be_default_quant_scheme.input_activations.symmetric + assert should_be_default_quant_scheme.weights.num_bits == 8 + assert should_be_default_quant_scheme.weights.symmetric + + +def test_set_quant_if_modifer_already_exists(): + setup_modifier_factory() + + model = LinearNet() + kwargs = dict( + scheme=dict( + input_activations=dict(num_bits=8, symmetric=True), + weights=dict(num_bits=4, symmetric=False), + ), + ) + + modifier = QuantizationModifierPyTorch(**kwargs) + testing_harness = LifecyleTestingHarness(model=model) + + assert not testing_harness.get_state().model.qat_active() + modifier.initialize(testing_harness.get_state()) + assert testing_harness.get_state().model.qat_active() + + kwargs = dict(sparsity=0.5, block_size=128, quantize=False) + modifier = SparseGPTModifierPyTorch(**kwargs) + assert not modifier.quantize + modifier.pre_initialize_structure(testing_harness.get_state()) + + # quantization modifier not owned by SparseGPT + assert modifier.quantization_modifier_ is None + + # since quantization modifier is already applied, quantization must be set in OBCQ + assert modifier.quantize + + +def test_set_quant_in_sparsegpt(): + setup_modifier_factory() + + quant_kwargs = { + "scheme": { + "input_activations": { + "num_bits": 8, + "symmetric": False, + "strategy": "tensor", + "kwargs": {}, + }, + "weights": { + "num_bits": 4, + "symmetric": True, + "strategy": "channel", + "kwargs": {}, + }, + } + } + quant_config = {"QuantizationModifier": quant_kwargs} + + kwargs = dict(sparsity=0.5, block_size=128, quantize=quant_config) + + modifier = SparseGPTModifierPyTorch(**kwargs) + assert modifier.quantization_modifier_ is None + + 
testing_harness = LifecyleTestingHarness(model=LinearNet()) + modifier.pre_initialize_structure(testing_harness.get_state()) + assert modifier.quantize + assert isinstance(modifier.quantization_modifier_, QuantizationModifier) + + dict_scheme = dict(modifier.quantization_modifier_.scheme) + assert dict(dict_scheme["weights"]) == quant_kwargs["scheme"]["weights"] + assert ( + dict(dict_scheme["input_activations"]) + == quant_kwargs["scheme"]["input_activations"] + ) From ebdf35a498e72b1b875030b65c4bde495e247d8e Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 18 Oct 2023 21:20:09 -0400 Subject: [PATCH 4/5] clean recipe --- src/sparseml/transformers/sparsification/obcq/example.yaml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/sparseml/transformers/sparsification/obcq/example.yaml b/src/sparseml/transformers/sparsification/obcq/example.yaml index d971eed1385..aa051b667d7 100644 --- a/src/sparseml/transformers/sparsification/obcq/example.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example.yaml @@ -45,8 +45,4 @@ test_stage: "model.decoder.layers.23" ] target_ids: ["attention_mask"] - layer_prefix: "decoder" - -# Llama model.model.layer.0 -# OPT model model.model.decoder.layer.0 -# other model model.model.encoder.extra.for.some.reason.layer.0 \ No newline at end of file + layer_prefix: "decoder" \ No newline at end of file From d651ca0b933442afe6e5f5b1154de832a5deac0a Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Thu, 26 Oct 2023 16:41:37 -0400 Subject: [PATCH 5/5] docstrings for new quantization situation --- src/sparseml/modifiers/obcq/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py index 8f08ac616aa..59aa076768b 100644 --- a/src/sparseml/modifiers/obcq/base.py +++ b/src/sparseml/modifiers/obcq/base.py @@ -37,7 +37,9 @@ class SparseGPTModifier(Modifier): :param sparsity: Sparsity to compress model to :param block_size: Used to determine number of columns to compress in one pass - :param quantize: Whether or not model is quantized (affects layer names) + :param quantize: Whether or not to quantize weights during SparseGPT. Set to True + to quantize using an existing quantization modifier, or pass in the configuration + for a quantization modifier if one does not already exist in the recipe :param dampening_frac: Amount of dampening to apply to H, as a fraction of the diagonal norm :param sequential_update: Whether or not to update weights sequentially by layer,
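
For reference, a condensed sketch of the three ways the reworked quantize field resolves during pre_initialize_structure; the constructor arguments mirror the new unit tests, and the 8-bit/4-bit scheme values are illustrative only:

from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch

# quantize=True: reuse a quantization modifier already applied to the model, or
# fall back to a default 8-bit QuantizationModifier if no QAT is set up yet
obcq = SparseGPTModifierPyTorch(sparsity=0.5, block_size=128, quantize=True)

# quantize=False: honored only when no quantization is active; if QAT is already
# set up on the model, pre_initialize_structure resets this to True with a warning
obcq = SparseGPTModifierPyTorch(sparsity=0.5, block_size=128, quantize=False)

# quantize=<single modifier config>: built through ModifierFactory and owned by
# the SparseGPT modifier (ignored with a warning if QAT is already active)
obcq = SparseGPTModifierPyTorch(
    sparsity=0.5,
    block_size=128,
    quantize={
        "QuantizationModifier": {
            "scheme": {
                "input_activations": {"num_bits": 8, "symmetric": False},
                "weights": {"num_bits": 4, "symmetric": True},
            }
        }
    },
)
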