allow HFQuantizer to run run_compressed=False mode
horheynm committed Nov 18, 2024
1 parent ff121cc commit 81fa1bb
Showing 5 changed files with 24 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/compressed_tensors/compressors/base.py
@@ -125,6 +125,8 @@ def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
             return None  # module is not quantized
         quantization_scheme = module.quantization_scheme
         if not hasattr(quantization_scheme, "weights"):
+            # models that ran CompressedLinear.from_linear will
+            # run delattr(module, "weight")
            return None  # weights are not quantized

         quantization_args = quantization_scheme.weights
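
The new comment explains why a module can carry a quantization scheme yet have nothing left to compress: converting a layer through CompressedLinear.from_linear removes the dense weight attribute. A minimal sketch of that effect, using a plain torch.nn.Linear stand-in rather than code from this commit:

    import torch
    from torch.nn import Linear

    layer = Linear(16, 16)
    # roughly what CompressedLinear.from_linear does to the dense parameter
    # after registering the compressed parameters in its place
    delattr(layer, "weight")
    print(hasattr(layer, "weight"))  # False -> compress_module skips this module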
@@ -103,7 +103,9 @@ def from_pretrained(
         :return: compressor for the configs, or None if model is not compressed
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+        compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None) or getattr(
+            config, QUANTIZATION_CONFIG_NAME, None
+        )
         return cls.from_compression_config(compression_config)

     @classmethod
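
The fallback above lets from_pretrained discover compression metadata whether the checkpoint config stores it under the legacy compression_config attribute or the HF-style quantization_config attribute. A minimal sketch of the lookup order, assuming the two constants resolve to those attribute names (an assumption, not shown in this diff):

    # hypothetical stand-ins for the library constants and the loaded config
    COMPRESSION_CONFIG_NAME = "compression_config"
    QUANTIZATION_CONFIG_NAME = "quantization_config"

    class FakeAutoConfig:
        # HF-quantized checkpoints typically carry the details here
        quantization_config = {"quant_method": "compressed-tensors"}

    config = FakeAutoConfig()
    compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None) or getattr(
        config, QUANTIZATION_CONFIG_NAME, None
    )
    assert compression_config == {"quant_method": "compressed-tensors"}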
4 changes: 3 additions & 1 deletion src/compressed_tensors/linear/compressed_linear.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Dict, Tuple
+
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import (
@@ -53,7 +55,7 @@ def from_linear(
         )

         # get the shape and dtype of compressed parameters
-        compression_params = module.compressor.compression_param_info(
+        compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
             module.weight.shape, quantization_scheme.weights
         )

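
The added annotation documents what compression_param_info hands back: a mapping from compressed-parameter name to the information needed to allocate it. An illustrative value, assuming each tuple holds (shape, dtype) and a format that packs eight int4 weights per int32 word; the key name and packing factor are assumptions, not taken from this diff:

    from typing import Dict, Tuple

    import torch

    out_features, in_features = 4096, 4096
    # hypothetical result for a pack-quantized int4 weight
    compression_params: Dict[str, Tuple] = {
        "weight_packed": ((out_features, in_features // 8), torch.int32),
    }
    for name, (shape, dtype) in compression_params.items():
        # from_linear registers an empty parameter per entry, roughly like this
        placeholder = torch.empty(shape, dtype=dtype)
        print(name, tuple(placeholder.shape), placeholder.dtype)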
3 changes: 2 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
@@ -106,7 +106,8 @@ def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
 ) -> OrderedDict:
     """
-    Initializes the model for quantization in-place based on the given config
+    Initializes the model for quantization in-place based on the given config.
+    Optionally coverts quantizable modules to compressed_linear modules

     :param model: model to apply quantization config to
     :param config: quantization config
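
A hedged usage sketch of the flag: run_compressed=True asks the apply step to swap quantizable Linear modules for CompressedLinear so the model executes directly on compressed weights, while the default False only initializes quantization in place. The model and config stand-ins below are illustrative, not part of the commit:

    import torch
    from compressed_tensors.quantization import (
        QuantizationConfig,
        apply_quantization_config,
    )

    model = torch.nn.Sequential(torch.nn.Linear(8, 8))
    # empty config_groups keeps the sketch self-contained; a real checkpoint
    # config would target specific layer types
    config = QuantizationConfig(config_groups={})

    # default path: initialize quantization in-place, keep torch.nn.Linear
    apply_quantization_config(model, config, run_compressed=False)

    # compressed-execution path: convert matching modules to CompressedLinear
    # (a real loader would pick one path, not run both)
    apply_quantization_config(model, config, run_compressed=True)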
16 changes: 14 additions & 2 deletions src/compressed_tensors/quantization/quant_config.py
@@ -35,6 +35,7 @@
 __all__ = [
     "QuantizationStatus",
     "QuantizationConfig",
+    "QuantizationConfigHFQuantizer",
     "LIFECYCLE_ORDER",
     "DEFAULT_QUANTIZATION_METHOD",
     "DEFAULT_QUANTIZATION_FORMAT",
@@ -132,9 +133,9 @@ class QuantizationConfig(BaseModel):
         `k_proj` and `v_proj` in their names. If this is not the case
         and kv_cache_scheme != None, the quantization of kv cache will fail
     :global_compression_ratio: optional informational config to report the model
-    compression ratio acheived by the quantization config
+        compression ratio acheived by the quantization config
     :ignore: optional list of layers to ignore from config_groups. Layers in this list
-    are not quantized even if they match up with a target in config_groups
+        are not quantized even if they match up with a target in config_groups
     """

     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
@@ -262,3 +263,14 @@ def requires_calibration_data(self):
             return True

         return False
+
+
+# For HFQuantizer, be able to adjust run_compressed on model load
+class QuantizationConfigHFQuantizer(QuantizationConfig):
+    """
+    :param run_compressed: param used set run_compressed.
+        Used for `apply_quantization_config` in CompressedTensorsHfQuantizer
+        in transformers
+    """
+
+    run_compressed: bool = True
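
A brief sketch of how the subclass carries the extra flag for transformers, assuming it is re-exported alongside QuantizationConfig and using an illustrative empty grouping:

    from compressed_tensors.quantization import QuantizationConfigHFQuantizer

    # run_compressed=False asks the HFQuantizer path to decompress on load
    config = QuantizationConfigHFQuantizer(config_groups={}, run_compressed=False)
    print(config.run_compressed)  # False

    # the field defaults to True, matching the class definition above
    print(QuantizationConfigHFQuantizer(config_groups={}).run_compressed)  # True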
