Quantization Compressor Support #2260

Merged
merged 70 commits on May 9, 2024
Changes from 62 commits
Commits
70 commits
097bd79
initial commit
dbogunowicz Apr 8, 2024
76970e3
update setup.py
dbogunowicz Apr 8, 2024
bbf4b39
Update setup.py
dbogunowicz Apr 8, 2024
a272a30
fix setup.py
dbogunowicz Apr 8, 2024
c0d3ead
move all config to sparsetensors
Apr 10, 2024
b3f7ff3
Merge branch 'main' into feature/damian/sparsetensors
Apr 10, 2024
a75f8da
cleanup class name and comments
Apr 10, 2024
c5b897e
Merge branch 'main' into feature/damian/sparsetensors
Apr 16, 2024
2c72ab1
initial implementation untested
Apr 16, 2024
9174c1d
fixing issues
Apr 16, 2024
aa17e77
add test script
Apr 17, 2024
f1f114c
update perplexity test
Apr 17, 2024
bbbdcb9
refactor to compressed-tensors
dbogunowicz Apr 18, 2024
5d9c7dd
Merge branch 'main' into feature/damian/sparsetensors
dbogunowicz Apr 18, 2024
7a9f9e5
rename sparsetensors
Apr 18, 2024
fa43088
update setup
Apr 18, 2024
63266d8
Sa/model reload (#2250)
Apr 19, 2024
b0f0fc9
Merge branch 'main' into sa/quant_mod_refactor
Apr 22, 2024
dfa41fb
Merge branch 'main' into feature/damian/sparsetensors
Apr 22, 2024
4af4852
Merge branch 'feature/damian/sparsetensors' into sa/quant_mod_refactor
Apr 22, 2024
55976c5
cleanup
Apr 22, 2024
38f4f77
refactor tests
Apr 22, 2024
6574874
only run oneshot once
Apr 22, 2024
7f5babf
all tests passing
dbogunowicz Apr 23, 2024
c0d6cb9
remove unused config
dbogunowicz Apr 23, 2024
a59e2af
reset models on each parameterize
Apr 23, 2024
cba7c27
Merge branch 'feature/damian/sparsetensors' into sa/quant_mod_refactor
Apr 23, 2024
2a6b0f2
style
Apr 23, 2024
1e7ee94
Merge branch 'main' into feature/damian/sparsetensors
dbogunowicz Apr 24, 2024
a4e0575
bring back SparsityConfigMetadata
dbogunowicz Apr 24, 2024
06d4554
Merge branch 'feature/damian/sparsetensors' of github.com:neuralmagic…
dbogunowicz Apr 24, 2024
644da53
Merge remote-tracking branch 'origin/feature/damian/sparsetensors' in…
dbogunowicz Apr 24, 2024
8ac18e7
Update setup.py
dbogunowicz Apr 24, 2024
de78247
add more comparisons, tighten threshold
Apr 25, 2024
4041f2e
use wikitext for perplexity
Apr 25, 2024
f5adc4e
Merge branch 'main' into feature/damian/sparsetensors
dbogunowicz Apr 25, 2024
c220772
update setup
dbogunowicz Apr 25, 2024
2fe554e
fix import problem
dbogunowicz Apr 25, 2024
4e0413e
fix clearml test
dbogunowicz Apr 25, 2024
a98a193
compressed-tensors are transformers dep
dbogunowicz Apr 25, 2024
b9b684c
Merge branch 'feature/damian/sparsetensors' into sa/quant_mod_refactor
Apr 25, 2024
f4362cf
address PR comments
Apr 25, 2024
ca91c4f
can't repeat freeze
Apr 26, 2024
c894305
UX pr comments
Apr 26, 2024
604c4ef
initial commit
Apr 29, 2024
82d3dd8
style
Apr 29, 2024
b650a8c
skipping unit tests
Apr 30, 2024
6a12295
tests for quantization
Apr 30, 2024
c1e0379
reloading unit tests
Apr 30, 2024
ba397fc
backwards compat
Apr 30, 2024
0a0ef06
test updates
Apr 30, 2024
b03d138
update format
Apr 30, 2024
2931d8e
fix inferring
May 1, 2024
1c3b31b
Merge branch 'main' into sa/quant_mod_refactor
May 1, 2024
90795bd
quality
May 1, 2024
c287c05
Merge branch 'sa/quant_mod_refactor' into sa/compressors
May 1, 2024
bf7d0f6
shape consistency
horheynm May 1, 2024
579d201
Merge branch 'sa/quant_mod_refactor' of github.com:neuralmagic/sparse…
horheynm May 1, 2024
2432cf4
address PR comments
May 2, 2024
24437c7
PR comments
May 3, 2024
399087f
Merge branch 'sa/quant_mod_refactor' into sa/compressors
May 3, 2024
e8bc021
fixing some things
May 3, 2024
3ca0298
Merge branch 'main' into sa/compressors
May 7, 2024
061de67
style
May 7, 2024
633d5a5
Merge branch 'main' into sa/compressors
May 8, 2024
6e0f1bc
pull from cp main
May 8, 2024
3ff4dc8
postmerge too
May 8, 2024
6f4379c
Merge branch 'main' into sa/compressors
May 8, 2024
29a2186
export needs it too
May 8, 2024
e93257f
Update src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
May 9, 2024
@@ -0,0 +1,15 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
enable_cpu_affinity: false
gpu_ids: 0
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
34 changes: 34 additions & 0 deletions src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
@@ -171,6 +171,40 @@ def fasterprune(
    else:
        q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
    q = torch.dequantize(q)
elif hasattr(self.layer, "quantization_scheme"):
    quant_scheme = self.layer.quantization_scheme
    if quant_scheme.weights is not None:
        scale = self.layer.weight_scale
        zero_point = self.layer.weight_zero_point
        from compressed_tensors.quantization import QuantizationStrategy
        from compressed_tensors.quantization.lifecycle.forward import (
            fake_quantize,
        )

        if quant_scheme.weights.strategy == QuantizationStrategy.TENSOR:
            q = fake_quantize(
                q,
                scale,
                zero_point,
                self.layer.quantization_scheme.weights,
            )
        else:
            while scale.ndim < 2:
                scale = scale.unsqueeze(scale.ndim)
                zero_point = zero_point.unsqueeze(zero_point.ndim)

            while q.ndim < 2:
                q = q.unsqueeze(q.ndim)

            q = fake_quantize(
                q,
                scale[:, i],
                zero_point[:, i],
                self.layer.quantization_scheme.weights,
            )

            while q.ndim > 1:
                q = q.squeeze()

Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
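For context on the new branch above: fake quantization rounds the weight column onto the integer grid defined by the scale and zero point and then maps it back to floating point, so the pruning loss is computed against values the quantized model can actually represent. A minimal sketch of that round trip, assuming a standard affine scheme (the real fake_quantize in compressed-tensors also handles the scheme's bit width and per-channel broadcasting):

```python
import torch

def affine_fake_quantize(x, scale, zero_point, num_bits=8):
    # Quantize-dequantize round trip: snap x onto the integer grid implied by
    # scale/zero_point, clamp to the representable range, then map back to float.
    q_min = -(2 ** (num_bits - 1))
    q_max = 2 ** (num_bits - 1) - 1
    q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)
    return (q - zero_point) * scale
```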
17 changes: 17 additions & 0 deletions src/sparseml/modifiers/quantization_vllm/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .base import *
83 changes: 83 additions & 0 deletions src/sparseml/modifiers/quantization_vllm/base.py
@@ -0,0 +1,83 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional

from pydantic import Field

from compressed_tensors.quantization import (
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStatus,
)
from sparseml.core import Event, Modifier


__all__ = ["vLLMQuantizationModifier"]


class vLLMQuantizationModifier(Modifier):
    """
    Enables post training quantization (PTQ) and quantization aware training (QAT)
    for a given module or its submodules. After calibration (PTQ) or the start epoch
    (QAT), the specified module(s)' forward pass will emulate quantized execution and
    the modifier will be enabled until training is completed.

    :param config_groups: dictionary specifying quantization schemes to apply to
        target modules. Modules not matching a scheme target will NOT be quantized.
    :param ignore: optional list of module class names or submodule names to not
        quantize even if they match a target in config_groups. Defaults to empty list.
    :param disable_quantization_observer_epoch: Epoch to disable updates to the module
        quantization observers. At this point, quantized weights and zero points will
        not be updated. Leave None to not disable observers during QAT. Default is None
    :param num_calibration_steps: Number of steps to run post training calibration
        for. When None, the entire calibration_dataloader is used
    """

    config_groups: Dict[str, QuantizationScheme]
    ignore: List[str] = Field(default_factory=list)
    disable_quantization_observer_epoch: Optional[float] = None
    num_calibration_steps: Optional[int] = None

    def create_init_config(self) -> QuantizationConfig:
        return QuantizationConfig(
            config_groups=self.config_groups,
            quantization_status=QuantizationStatus.INITIALIZED,
            ignore=self.ignore,
        )

    def calculate_disable_observer_epoch(self) -> float:
        """
        Get the epoch at which we want to disable the quantization observer

        :return: epoch to disable at, or -1 if it is not set
        """
        return (
            self.disable_quantization_observer_epoch
            if self.disable_quantization_observer_epoch is not None
            else -1
        )

    def check_should_disable_observer(self, event: Event) -> bool:
        """
        Given the current index, determine if we should disable the observer

        :param event: Event to get index from
        :return: True if observer should be disabled, False otherwise
        """
        disable_epoch = self.calculate_disable_observer_epoch()
        if disable_epoch == -1:
            return False
        if event.current_index >= disable_epoch:
            return True
        return False
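As a rough usage sketch of the modifier defined above: a single config group targeting all Linear layers with 8-bit weights while skipping the LM head might be declared as below. The QuantizationScheme/QuantizationArgs field names follow compressed-tensors at the time of this PR, and "group_0", the Linear target, and the lm_head ignore entry are illustrative assumptions rather than values taken from this diff.

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

from sparseml.modifiers.quantization_vllm import vLLMQuantizationModifier

# Hypothetical example: 8-bit symmetric weight quantization for Linear layers,
# leaving the LM head unquantized.
modifier = vLLMQuantizationModifier(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    },
    ignore=["lm_head"],
)

# create_init_config() packages the groups into a QuantizationConfig with
# status INITIALIZED, ready to be applied to a model.
config = modifier.create_init_config()
```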
141 changes: 141 additions & 0 deletions src/sparseml/modifiers/quantization_vllm/pytorch.py
@@ -0,0 +1,141 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any

from torch.nn import Module

from compressed_tensors.quantization import (
    apply_quantization_config,
    freeze_module_quantization,
    set_module_for_calibration,
)
from sparseml.core import Event, EventType, State
from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier
from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward


_LOGGER = logging.getLogger(__name__)


class vLLMQuantizationModifierPyTorch(vLLMQuantizationModifier):
    """
    PyTorch-specific implementation of vLLMQuantizationModifier

    Enables post training quantization (PTQ) and quantization aware training (QAT)
    for a given module or its submodules. After calibration (PTQ) or the start epoch
    (QAT), the specified module(s)' forward pass will emulate quantized execution and
    the modifier will be enabled until training is completed.

    :param config_groups: dictionary specifying quantization schemes to apply to
        target modules. Modules not matching a scheme target will NOT be quantized.
    :param ignore: optional list of module class names or submodule names to not
        quantize even if they match a target in config_groups. Defaults to empty list.
    :param disable_quantization_observer_epoch: Epoch to disable updates to the module
        quantization observers. At this point, quantized weights and zero points will
        not be updated. Leave None to not disable observers during QAT. Default is None
    :param num_calibration_steps: Number of steps to run post training calibration
        for. When None, the entire calibration_dataloader is used
    """

    calibration_dataloader_: Any = None
    calibration_function_: Any = None

    def on_initialize_structure(self, state: State, **kwargs):
        module = state.model.model
        self._apply_modifier_to_model(module)
        module.apply(freeze_module_quantization)

    def on_initialize(self, state: State, **kwargs) -> bool:
        if self.end and self.end != -1:
            raise ValueError(
                "end_epoch is disabled for QuantizationModifier and can only be set to"
                " -1 or None. Given {}".format(self.end)
            )

        self.calibration_dataloader_ = state.data.calib
        module = state.model.model

        # initialize quantization in appropriate modules
        self._apply_modifier_to_model(module)

        if self.calculate_start() == -1:  # one-shot
            module.apply(set_module_for_calibration)
            self._calibrate_if_possible(module)
            module.apply(freeze_module_quantization)

        return True

    def on_finalize(self, state: State, **kwargs) -> bool:
        return True

    def on_start(self, state: State, event: Event, **kwargs):
        module = state.model.model
        module.apply(set_module_for_calibration)

    def on_update(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.BATCH_START:
            if self.check_should_disable_observer(event):
                module = state.model.model
                module.apply(freeze_module_quantization)

    def on_end(self, state: State, event: Event, **kwargs):
        module = state.model.model
        module.apply(freeze_module_quantization)

    def on_event(self, state: State, event: Event, **kwargs):
        pass

    def _apply_modifier_to_model(self, model: Module):
        modifier_as_config = self.create_init_config()
        apply_quantization_config(model, modifier_as_config)

    def _calibrate_if_possible(self, module: Module):
        if self.num_calibration_steps == 0 and self.calibration_dataloader_:
            _LOGGER.warning(
                f"num_calibration_steps is {self.num_calibration_steps}. "
                "Calibration data loader will not be used."
            )
        elif self.num_calibration_steps and not self.calibration_dataloader_:
            raise ValueError(
                f"num_calibration_steps is {self.num_calibration_steps}. "
                "Calibration data loader is not set. Pass a "
                "calibration_data_loader with initialize(...) method."
            )

        elif not self.calibration_dataloader_:
            return

        self._calibrate(module)

    def _calibrate(self, module: Module):
        class_name = self.__class__.__name__.replace("PyTorch", "")
        _LOGGER.info(
            f"Running {class_name} calibration with "
            f"{len(self.calibration_dataloader_)} samples..."
        )

        module_training = module.training
        module.eval()

        run_calibration_forward(
            module,
            self.calibration_dataloader_,
            self.num_calibration_steps,
            self.calibration_function_,
        )

        if module_training:
            module.train()
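The one-shot path in on_initialize above reduces to: apply the quantization config, run calibration forward passes so the observers collect scales and zero points, then freeze. A condensed sketch of that flow using the same compressed-tensors helpers — the explicit loop and dict-style batches are assumptions for illustration; the modifier itself delegates this to run_calibration_forward:

```python
import torch
from compressed_tensors.quantization import (
    apply_quantization_config,
    freeze_module_quantization,
    set_module_for_calibration,
)

def one_shot_ptq(model, quant_config, calibration_dataloader):
    # Attach quantization parameters and observers to modules matching the config
    apply_quantization_config(model, quant_config)
    # Calibration mode: observers update scales/zero-points from activations
    model.apply(set_module_for_calibration)
    model.eval()
    with torch.no_grad():
        for batch in calibration_dataloader:
            model(**batch)  # assumes each batch is a dict of input tensors
    # Freeze quantization parameters; later forwards emulate quantized execution
    model.apply(freeze_module_quantization)
    return model
```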
48 changes: 48 additions & 0 deletions src/sparseml/transformers/compression/quantization_format.py
@@ -0,0 +1,48 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional

from compressed_tensors import CompressionFormat
from compressed_tensors.quantization.utils import is_model_quantized


__all__ = ["infer_quantization_format"]


def infer_quantization_format(
    model, quantization_format: Optional[str] = None, save_compressed: bool = False
) -> Optional[str]:
    """
    Infers a quantization format based on model state and compression args

    :param model: model to check for quantization; if the model is not quantized no
        quantization format is returned
    :param quantization_format: user provided quantization format, supersedes any
        inferred quantization format
    :param save_compressed: used to infer a quantization format if None is provided
    :return: compression format appropriate for model
    """
    if not is_model_quantized(model):
        return None

    if quantization_format is not None:
        return quantization_format

    if save_compressed:
        return CompressionFormat.int_quantized
    else:
        # format will be inferred from config
        return None
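A quick sanity check of the decision order above (quantized? user-provided format? save_compressed?), as a hypothetical example — a plain, unquantized module always yields None regardless of the other arguments:

```python
import torch

from sparseml.transformers.compression.quantization_format import (
    infer_quantization_format,
)

# An unquantized module: no quantization format is inferred, even when asking
# for compressed saving.
model = torch.nn.Linear(8, 8)
assert infer_quantization_format(model, save_compressed=True) is None

# For a model that has been quantized (e.g. via vLLMQuantizationModifier), the
# same call would return the user-provided format if given, otherwise
# CompressionFormat.int_quantized when save_compressed=True, else None.
```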