Define Quantization within SparseGPTModifier #1776

Merged: 9 commits, Oct 26, 2023
Changes from 7 commits
8 changes: 8 additions & 0 deletions src/sparseml/core/model/base.py
@@ -116,3 +116,11 @@ def set_param(self, target: str, param: PT):
:param param: the param instance to set
"""
raise NotImplementedError()

def qat_active(self) -> bool:
"""
Checks if quantization-aware training (QAT) is set up in the model

:return: True if QAT is active in any layer, False otherwise
"""
raise NotImplementedError()
9 changes: 9 additions & 0 deletions src/sparseml/core/model/pytorch.py
@@ -24,6 +24,7 @@
get_layers_params,
get_param,
get_params,
qat_active,
set_layer,
set_param,
)
@@ -94,3 +95,11 @@ def set_param(self, target: str, param: Parameter):
:param param: the parameter to set
"""
return set_param(target, param, self.model)

def qat_active(self) -> bool:
"""
Checks if quantization-aware training (QAT) is set up in the model

:return: True if QAT is active in any layer, False otherwise
"""
return qat_active(self.model)
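
With the wrapper delegating to the torch-level helper, callers that only hold the framework-agnostic model object can query QAT status without importing anything torch-specific. A minimal sketch, assuming ModifiableModel can be built directly from a framework enum and a raw torch module (the constructor details are an assumption, not shown in this diff):

from torch import nn

from sparseml.core.framework import Framework
from sparseml.core.model import ModifiableModel

# Illustrative only: exact constructor arguments may differ from this sketch
wrapped = ModifiableModel(framework=Framework.pytorch, model=nn.Linear(8, 8))
print(wrapped.qat_active())  # False until FakeQuantize observers are attached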
82 changes: 77 additions & 5 deletions src/sparseml/modifiers/obcq/base.py
@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Union
import logging
from typing import Any, Dict, List, Optional, Union

from sparseml.core import Modifier
from sparseml.core.factory import ModifierFactory
from sparseml.core.state import State
from sparseml.utils import ALL_TOKEN


__all__ = ["SparseGPTModifier"]

_LOGGER = logging.getLogger(__name__)


class SparseGPTModifier(Modifier):
"""
@@ -50,7 +53,7 @@ class SparseGPTModifier(Modifier):

sparsity: float
block_size: int
quantize: bool
quantize: Union[bool, Dict]
dampening_frac: Optional[float] = 0.01
sequential_update: Optional[bool] = True
prunen: Optional[int] = 0
@@ -59,5 +62,74 @@
target_ids: Optional[List[str]] = None
layer_prefix: Optional[str] = None

def on_initialize_structure(self, state: "State", **kwargs):
pass # nothing needed for this modifier
compressible_layers_: List = None
quantization_modifier_: Any = None

def compressible_layers(self) -> List:
"""
Retrieves the modules corresponding to a list of compressible layer names

:return: list of Pytorch modules to compress
"""
compressible_dict = self.model.get_layers(self.targets)
return [v for _, v in compressible_dict.items()]

def pre_initialize_structure(self, state: State, **kwargs):
quantization_already_active = state.model.qat_active()
if isinstance(self.quantize, bool):
if not self.quantize and quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to False, but a "
"quantization modifier is already active on the model "
"resetting quantize to True"
)
self.quantize = True
elif self.quantize and not quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to True without an "
"active quantization modifier. Creating a default "
"8-bit quantization modifier"
)
default_quant_config = {"QuantizationModifier": {}}
self._build_quant_modifier_from_dict(
default_quant_config, state.framework
)
return # use existing quantization modifier if there is one
else:
if not isinstance(self.quantize, dict):
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"type {type(self.quantize)}"
)
if len(self.quantize) != 1:
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"{len(self.quantize)} modifiers"
)
if quantization_already_active:
_LOGGER.warning(
"Attempting to initialize quantization for SparseGPT "
"but a quantization modifier has already been applied. "
"The quantization configuration defined under the "
"SparseGPT modifier will be ignored."
)
self.quantize = True
return
self._build_quant_modifier_from_dict(self.quantize, state.framework)
self.quantize = True

if self.quantization_modifier_:
self.quantization_modifier_.pre_initialize_structure(state, **kwargs)

def _build_quant_modifier_from_dict(self, quant_config, framework):
modifier_type = list(quant_config.keys())[0]
modifier_args = quant_config[modifier_type]
self.quantization_modifier_ = ModifierFactory.create(
modifier_type,
framework=framework,
allow_registered=True,
allow_experimental=True,
**modifier_args,
)
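
To make the accepted shapes of quantize concrete, the sketch below mirrors the branching in pre_initialize_structure and the tests added at the bottom of this PR; the comments describe what would happen when pre_initialize_structure runs, assuming no quantization modifier is already active on the model:

from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch

# quantize=True with no active quantization: a default 8-bit
# QuantizationModifier is built when pre_initialize_structure runs
default_quant = SparseGPTModifierPyTorch(sparsity=0.5, block_size=128, quantize=True)

# quantize as a single-entry dict: the nested config is handed to
# ModifierFactory.create and quantize is then flipped to True
explicit_quant = SparseGPTModifierPyTorch(
    sparsity=0.5,
    block_size=128,
    quantize={"QuantizationModifier": {"ignore": ["lm_head"]}},
)

# quantize=False: no quantization modifier is created unless QAT is
# already active on the model, in which case quantize is reset to True
no_quant = SparseGPTModifierPyTorch(sparsity=0.5, block_size=128, quantize=False)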
19 changes: 7 additions & 12 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -17,7 +17,6 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch.nn import Module

from sparseml.core.model import ModifiableModel
from sparseml.core.state import State
@@ -46,25 +45,19 @@ class SparseGPTModifierPyTorch(SparseGPTModifier):
"""

model: Any = None
compressible_layers_: List = None
device_: str = "cuda:0"
finalization_kwargs_: Dict = None

def compressible_layers(self) -> List[Module]:
"""
Retrieves the modules corresponding to a list of compressible layer names

:return: list of Pytorch modules to compress
"""
compressible_dict = self.model.get_layers(self.targets)
return [v for _, v in compressible_dict.items()]

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state

:param state: session state storing input model and calibration data
"""
if not self.initialized_structure_:
self.pre_initialize_structure(state, **kwargs)
if self.quantization_modifier_:
self.quantization_modifier_.initialize(state, **kwargs)
self.finalization_kwargs_ = {}
modifiable_model = state.model
calibration_dataloader = state.data.calib
@@ -151,9 +144,11 @@ def on_finalize(self, state: "State", **kwargs) -> bool:
:param state: unused, for matching spec of Modifier base class
"""
use_cache = self.finalization_kwargs_.get("use_cache", False)
self.model.apply(torch.quantization.disable_observer)
self.model.config.use_cache = use_cache

if self.quantization_modifier_:
self.quantization_modifier_.finalize(state, **kwargs)

return True

def compress_bottom(
22 changes: 11 additions & 11 deletions src/sparseml/transformers/sparsification/obcq/example.yaml
@@ -1,20 +1,20 @@
test_stage:
obcq_modifiers:
QuantizationModifier:
ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"]
post_oneshot_calibration: True
scheme_overrides:
ReLU:
input_activations: null
output_activations: null
LayerNorm:
input_activations: null
output_activations: null
SparseGPTModifier:
sparsity: 0.5
block_size: 128
sequential_update: False
quantize: True
quantize:
QuantizationModifier:
ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"]
post_oneshot_calibration: True
scheme_overrides:
ReLU:
input_activations: null
output_activations: null
LayerNorm:
input_activations: null
output_activations: null
percdamp: 0.01
prunen: 0
prunem: 0
16 changes: 16 additions & 0 deletions src/sparseml/utils/pytorch/module.py
@@ -65,6 +65,7 @@
"get_terminal_layers",
"get_prunable_layers",
"get_quantizable_layers",
"qat_active",
"get_layers_params",
]

@@ -241,6 +242,21 @@ def get_quantizable_layers(module: Module) -> Dict[str, Module]:
return quantizable


def qat_active(module: Module) -> bool:
"""
Determines if any layers in the model have quantization enabled by checking for
FakeQuantize submodules

:param module: PyTorch model to check for quantization
:return: True if quantization is active anywhere in the model, False otherwise
"""
for _, layer in module.named_modules():
if isinstance(layer, torch.quantization.FakeQuantize):
return True

return False


def get_layers_params(
targets: Union[str, List[str]], module: Module
) -> Dict[str, ModelParameterizedLayer[Parameter, Module]]:
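A quick usage sketch for the new qat_active helper, imported from the module shown above; the plain torch.quantization calls are just one illustrative way FakeQuantize modules end up attached, standing in for what a quantization modifier would do:

import torch
from torch import nn

from sparseml.utils.pytorch.module import qat_active

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
assert not qat_active(model)  # no FakeQuantize submodules yet

# attach FakeQuantize observers with vanilla torch QAT prep (illustrative)
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)
assert qat_active(model)  # FakeQuantize submodules are now present
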
13 changes: 13 additions & 0 deletions tests/sparseml/pytorch/modifiers/obcq/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
107 changes: 107 additions & 0 deletions tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py
@@ -0,0 +1,107 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch
from sparseml.modifiers.quantization import QuantizationModifier
from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch
from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory
from tests.sparseml.pytorch.helpers import LinearNet


def test_create_default_quant_modifier():
setup_modifier_factory()
kwargs = dict(sparsity=0.5, block_size=128, quantize=True)

modifier = SparseGPTModifierPyTorch(**kwargs)
assert modifier.quantization_modifier_ is None

testing_harness = LifecyleTestingHarness(model=LinearNet())
modifier.pre_initialize_structure(testing_harness.get_state())
assert modifier.quantize
assert isinstance(modifier.quantization_modifier_, QuantizationModifier)

should_be_default_quant_scheme = modifier.quantization_modifier_.scheme
assert should_be_default_quant_scheme.input_activations.num_bits == 8
assert not should_be_default_quant_scheme.input_activations.symmetric
assert should_be_default_quant_scheme.weights.num_bits == 8
assert should_be_default_quant_scheme.weights.symmetric


def test_set_quant_if_modifier_already_exists():
setup_modifier_factory()

model = LinearNet()
kwargs = dict(
scheme=dict(
input_activations=dict(num_bits=8, symmetric=True),
weights=dict(num_bits=4, symmetric=False),
),
)

modifier = QuantizationModifierPyTorch(**kwargs)
testing_harness = LifecyleTestingHarness(model=model)

assert not testing_harness.get_state().model.qat_active()
modifier.initialize(testing_harness.get_state())
assert testing_harness.get_state().model.qat_active()

kwargs = dict(sparsity=0.5, block_size=128, quantize=False)
modifier = SparseGPTModifierPyTorch(**kwargs)
assert not modifier.quantize
modifier.pre_initialize_structure(testing_harness.get_state())

# quantization modifier not owned by SparseGPT
assert modifier.quantization_modifier_ is None

# since quantization modifier is already applied, quantization must be set in OBCQ
assert modifier.quantize


def test_set_quant_in_sparsegpt():
setup_modifier_factory()

quant_kwargs = {
"scheme": {
"input_activations": {
"num_bits": 8,
"symmetric": False,
"strategy": "tensor",
"kwargs": {},
},
"weights": {
"num_bits": 4,
"symmetric": True,
"strategy": "channel",
"kwargs": {},
},
}
}
quant_config = {"QuantizationModifier": quant_kwargs}

kwargs = dict(sparsity=0.5, block_size=128, quantize=quant_config)

modifier = SparseGPTModifierPyTorch(**kwargs)
assert modifier.quantization_modifier_ is None

testing_harness = LifecyleTestingHarness(model=LinearNet())
modifier.pre_initialize_structure(testing_harness.get_state())
assert modifier.quantize
assert isinstance(modifier.quantization_modifier_, QuantizationModifier)

dict_scheme = dict(modifier.quantization_modifier_.scheme)
assert dict(dict_scheme["weights"]) == quant_kwargs["scheme"]["weights"]
assert (
dict(dict_scheme["input_activations"])
== quant_kwargs["scheme"]["input_activations"]
)