diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 6473554d..68bd52ec 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -24,7 +24,6 @@
 import torch
 import transformers
 from compressed_tensors.base import (
-    COMPRESSION_CONFIG_NAME,
     COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
@@ -39,6 +38,7 @@
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -103,12 +103,14 @@ def from_pretrained(
         :return: compressor for the configs, or None if model is not compressed
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+        compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
+
         return cls.from_compression_config(compression_config)
 
     @classmethod
     def from_compression_config(
-        cls, compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+        cls,
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
     ):
         """
         :param compression_config:
@@ -265,7 +267,11 @@ def compress(
             state_dict = model.state_dict()
 
         compressed_state_dict = state_dict
-        quantized_modules_to_args = map_modules_to_quant_args(model)
+
+        quantized_modules_to_args: Dict[
+            str, QuantizationArgs
+        ] = map_modules_to_quant_args(model)
+
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
                 state_dict, names_to_scheme=quantized_modules_to_args
@@ -369,7 +375,13 @@ def _replace_weights(self, dense_weight_generator, model):
             update_parameter_data(module, data, param_name)
 
 
-def map_modules_to_quant_args(model: Module) -> Dict:
+def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
+    """
+    Given a pytorch model, map out the submodule name (usually linear layers)
+    to the QuantizationArgs
+
+    :param model: pytorch model
+    """
     quantized_modules_to_args = {}
     for name, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
diff --git a/src/compressed_tensors/linear/compressed_linear.py b/src/compressed_tensors/linear/compressed_linear.py
index a4d5b532..3e2b2f5f 100644
--- a/src/compressed_tensors/linear/compressed_linear.py
+++ b/src/compressed_tensors/linear/compressed_linear.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Dict, Tuple
+
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import (
@@ -53,7 +55,7 @@ def from_linear(
         )
 
         # get the shape and dtype of compressed parameters
-        compression_params = module.compressor.compression_param_info(
+        compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
             module.weight.shape, quantization_scheme.weights
         )
 
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 7c498787..ed9a50f7 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -106,7 +106,8 @@ def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
 ) -> OrderedDict:
     """
-    Initializes the model for quantization in-place based on the given config
+    Initializes the model for quantization in-place based on the given config.
+    Optionally converts quantizable modules to compressed_linear modules
 
     :param model: model to apply quantization config to
     :param config: quantization config
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 04c8deb7..1d95aee8 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -132,9 +132,9 @@ class QuantizationConfig(BaseModel):
         `k_proj` and `v_proj` in their names. If this is not the case
         and kv_cache_scheme != None, the quantization of kv cache will fail
     :global_compression_ratio: optional informational config to report the model
-       compression ratio acheived by the quantization config
+        compression ratio acheived by the quantization config
     :ignore: optional list of layers to ignore from config_groups. Layers in this list
-       are not quantized even if they match up with a target in config_groups
+        are not quantized even if they match up with a target in config_groups
     """
 
     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 7fe90ca3..3a8152da 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -36,7 +36,7 @@ class QuantizationScheme(BaseModel):
     of modules should be quantized
 
     :param targets: list of modules to apply the QuantizationArgs to, can be layer
-        names, layer types or a regular expression
+        names, layer types or a regular expression, typically ["Linear"]
     :param weights: quantization config for layer weights
     :param input_activations: quantization config for layer inputs
     :param output_activations: quantization config for layer outputs
@@ -62,28 +62,6 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
 
         return model
 
-    @classmethod
-    def default_scheme(
-        cls,
-        targets: Optional[List[str]] = None,
-    ):
-        if targets is None:
-            # default to quantizing all Linear layers
-            targets = ["Linear"]
-
-        # by default, activations and weights are left unquantized
-        weights = None
-        input_activations = None
-        output_activations = None
-
-        return cls(
-            targets=targets,
-            weights=weights,
-            input_activations=input_activations,
-            output_activations=output_activations,
-        )
-
-
 """
 Pre-Set Quantization Scheme Args
 """
diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py
index 4e9839b9..7268ca27 100644
--- a/tests/test_quantization/lifecycle/test_apply.py
+++ b/tests/test_quantization/lifecycle/test_apply.py
@@ -28,6 +28,7 @@
     apply_quantization_status,
 )
 from compressed_tensors.quantization.utils import iter_named_leaf_modules
+from tests.testing_utils import requires_accelerate
 from transformers import AutoModelForCausalLM
 
 
@@ -224,6 +225,7 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
     return QuantizationConfig.parse_obj(config_dict)
 
 
+@requires_accelerate()
 @pytest.mark.parametrize(
     "ignore,should_raise_warning",
     [
diff --git a/tests/test_quantization/test_quant_scheme.py b/tests/test_quantization/test_quant_scheme.py
index 14ee5b72..0ea7f31f 100644
--- a/tests/test_quantization/test_quant_scheme.py
+++ b/tests/test_quantization/test_quant_scheme.py
@@ -53,7 +53,7 @@ def test_needs_targets():
 
 def test_defaults():
     targets = ["Linear"]
-    output = QuantizationScheme.default_scheme(targets=targets)
+    output = QuantizationScheme(targets=targets)
     assert output.weights is None
     assert output.input_activations is None
     assert output.output_activations is None
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 2e9be7cf..e446cad3 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -26,8 +26,29 @@ def compressed_tensors_config_available():
         return False
 
 
+def accelerate_available():
+    try:
+        import accelerate  # noqa: F401
+
+        return True
+
+    except ImportError:
+        return False
+
+
+_is_compressed_tensors_config_available = compressed_tensors_config_available()
+_is_accelerate_available = accelerate_available()
+
+
 def requires_hf_quantizer():
     return pytest.mark.skipif(
-        not compressed_tensors_config_available(),
+        not _is_compressed_tensors_config_available,
         reason="requires transformers>=4.45 to support CompressedTensorsHfQuantizer",
     )
+
+
+def requires_accelerate():
+    return pytest.mark.skipif(
+        not _is_accelerate_available,
+        reason="requires accelerate",
+    )
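
Not part of the patch: a minimal sketch of the default-construction path that replaces the removed QuantizationScheme.default_scheme helper, mirroring the updated test_defaults in the diff. The import path is assumed from the package's public quantization exports.

# Sketch only: direct construction now covers the old default_scheme() behaviour.
from compressed_tensors.quantization import QuantizationScheme

scheme = QuantizationScheme(targets=["Linear"])  # targets is the only required field
assert scheme.weights is None                    # weights left unquantized by default
assert scheme.input_activations is None          # inputs left unquantized by default
assert scheme.output_activations is None         # outputs left unquantized by default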