Merge branch 'main' into kylesayrs/quant-scheme-validation
kylesayrs authored Nov 23, 2024
2 parents a2218d0 + a26c03a commit 6f67abd
Showing 8 changed files with 50 additions and 34 deletions.
@@ -24,7 +24,6 @@
import torch
import transformers
from compressed_tensors.base import (
- COMPRESSION_CONFIG_NAME,
COMPRESSION_VERSION_NAME,
QUANTIZATION_CONFIG_NAME,
QUANTIZATION_METHOD_NAME,
@@ -39,6 +38,7 @@
apply_quantization_config,
load_pretrained_quantization,
)
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
is_module_quantized,
iter_named_leaf_modules,
@@ -103,12 +103,14 @@ def from_pretrained(
:return: compressor for the configs, or None if model is not compressed
"""
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
- compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+ compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)

return cls.from_compression_config(compression_config)

@classmethod
def from_compression_config(
- cls, compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+ cls,
+ compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
):
"""
:param compression_config:
@@ -265,7 +267,11 @@ def compress(
state_dict = model.state_dict()

compressed_state_dict = state_dict
- quantized_modules_to_args = map_modules_to_quant_args(model)

+ quantized_modules_to_args: Dict[
+ str, QuantizationArgs
+ ] = map_modules_to_quant_args(model)

if self.quantization_compressor is not None:
compressed_state_dict = self.quantization_compressor.compress(
state_dict, names_to_scheme=quantized_modules_to_args
@@ -369,7 +375,13 @@ def _replace_weights(self, dense_weight_generator, model):
update_parameter_data(module, data, param_name)


- def map_modules_to_quant_args(model: Module) -> Dict:
+ def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
+ """
+ Given a pytorch model, map out the submodule name (usually linear layers)
+ to the QuantizationArgs
+ :param model: pytorch model
+ """
quantized_modules_to_args = {}
for name, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
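
A minimal usage sketch of the typed mapping (not part of this diff): it assumes, per the docstring above, that each quantized leaf module exposes its scheme's weight args through a quantization_scheme attribute; the exact attribute path is an assumption, since the diff is truncated here.

from typing import Dict

from torch.nn import Module
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
    is_module_quantized,
    iter_named_leaf_modules,
)


def map_modules_to_quant_args_sketch(model: Module) -> Dict[str, QuantizationArgs]:
    # Walk leaf modules (typically Linear layers) and record the weight
    # QuantizationArgs for every module that carries a quantization scheme.
    mapping: Dict[str, QuantizationArgs] = {}
    for name, submodule in iter_named_leaf_modules(model):
        if is_module_quantized(submodule):
            # assumed attribute path, for illustration only
            mapping[name] = submodule.quantization_scheme.weights
    return mapping
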
4 changes: 3 additions & 1 deletion src/compressed_tensors/linear/compressed_linear.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+ from typing import Dict, Tuple

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import (
@@ -53,7 +55,7 @@ def from_linear(
)

# get the shape and dtype of compressed parameters
- compression_params = module.compressor.compression_param_info(
+ compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
module.weight.shape, quantization_scheme.weights
)

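
A small illustration of what the new Dict[str, Tuple] annotation describes, assuming (per the comment above) that compression_param_info maps each compressed parameter name to a (shape, dtype) pair; the names and sizes below are invented for illustration.

import torch

# hypothetical return value of compression_param_info
compression_params = {
    "weight_packed": (torch.Size([4096, 512]), torch.int32),
    "weight_scale": (torch.Size([4096, 1]), torch.float16),
}

# from_linear can then allocate matching, non-trainable parameters
for name, (shape, dtype) in compression_params.items():
    param = torch.nn.Parameter(torch.empty(shape, dtype=dtype), requires_grad=False)
    print(name, tuple(param.shape), param.dtype)
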
3 changes: 2 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
@@ -106,7 +106,8 @@ def apply_quantization_config(
model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
) -> OrderedDict:
"""
- Initializes the model for quantization in-place based on the given config
+ Initializes the model for quantization in-place based on the given config.
+ Optionally coverts quantizable modules to compressed_linear modules
:param model: model to apply quantization config to
:param config: quantization config
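
A hedged usage sketch of the behavior the added docstring line describes: with run_compressed=True, quantizable modules can be converted to compressed_linear modules. The model id and config dict below are illustrative placeholders, not taken from this commit.

from compressed_tensors.quantization import QuantizationConfig
from compressed_tensors.quantization.lifecycle import apply_quantization_config
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
config = QuantizationConfig.parse_obj(
    {
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {"num_bits": 8, "type": "int", "symmetric": True, "strategy": "channel"},
            }
        },
        "quantization_status": "frozen",
        "ignore": ["lm_head"],
    }
)
# run_compressed=True requests conversion of quantizable Linear layers to
# CompressedLinear instead of leaving them as plain quantized modules
names_to_scheme = apply_quantization_config(model, config, run_compressed=True)
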
4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/quant_config.py
@@ -132,9 +132,9 @@ class QuantizationConfig(BaseModel):
`k_proj` and `v_proj` in their names. If this is not the case
and kv_cache_scheme != None, the quantization of kv cache will fail
:global_compression_ratio: optional informational config to report the model
- compression ratio acheived by the quantization config
+ compression ratio acheived by the quantization config
:ignore: optional list of layers to ignore from config_groups. Layers in this list
- are not quantized even if they match up with a target in config_groups
+ are not quantized even if they match up with a target in config_groups
"""

config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
24 changes: 1 addition & 23 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -36,7 +36,7 @@ class QuantizationScheme(BaseModel):
of modules should be quantized
:param targets: list of modules to apply the QuantizationArgs to, can be layer
- names, layer types or a regular expression
+ names, layer types or a regular expression, typically ["Linear"]
:param weights: quantization config for layer weights
:param input_activations: quantization config for layer inputs
:param output_activations: quantization config for layer outputs
@@ -62,28 +62,6 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:

return model

- @classmethod
- def default_scheme(
- cls,
- targets: Optional[List[str]] = None,
- ):
- if targets is None:
- # default to quantizing all Linear layers
- targets = ["Linear"]
-
- # by default, activations and weights are left unquantized
- weights = None
- input_activations = None
- output_activations = None
-
- return cls(
- targets=targets,
- weights=weights,
- input_activations=input_activations,
- output_activations=output_activations,
- )


"""
Pre-Set Quantization Scheme Args
"""
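
With default_scheme() removed, the same defaults come straight from the pydantic model, as the updated test in test_quant_scheme.py below confirms; a minimal sketch:

from compressed_tensors.quantization.quant_scheme import QuantizationScheme

# weights and activations default to None (unquantized) unless specified
scheme = QuantizationScheme(targets=["Linear"])
assert scheme.weights is None
assert scheme.input_activations is None
assert scheme.output_activations is None
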
2 changes: 2 additions & 0 deletions tests/test_quantization/lifecycle/test_apply.py
@@ -28,6 +28,7 @@
apply_quantization_status,
)
from compressed_tensors.quantization.utils import iter_named_leaf_modules
+ from tests.testing_utils import requires_accelerate
from transformers import AutoModelForCausalLM


@@ -224,6 +225,7 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
return QuantizationConfig.parse_obj(config_dict)


+ @requires_accelerate()
@pytest.mark.parametrize(
"ignore,should_raise_warning",
[
2 changes: 1 addition & 1 deletion tests/test_quantization/test_quant_scheme.py
@@ -53,7 +53,7 @@ def test_needs_targets():

def test_defaults():
targets = ["Linear"]
- output = QuantizationScheme.default_scheme(targets=targets)
+ output = QuantizationScheme(targets=targets)
assert output.weights is None
assert output.input_activations is None
assert output.output_activations is None
23 changes: 22 additions & 1 deletion tests/testing_utils.py
@@ -26,8 +26,29 @@ def compressed_tensors_config_available():
return False


+ def accelerate_availabe():
+ try:
+ import accelerate  # noqa: F401
+
+ return True
+
+ except ImportError:
+ return False
+
+
+ _is_compressed_tensors_config_available = compressed_tensors_config_available()
+ _is_accelerate_available = accelerate_availabe()


def requires_hf_quantizer():
return pytest.mark.skipif(
- not compressed_tensors_config_available(),
+ not _is_compressed_tensors_config_available,
reason="requires transformers>=4.45 to support CompressedTensorsHfQuantizer",
)


+ def requires_accelerate():
+ return pytest.mark.skipif(
+ not _is_accelerate_available,
+ reason="requires accelerate",
+ )

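The new marker is used the same way as the existing requires_hf_quantizer helper: applied as a decorator so the test is skipped when accelerate is not installed, as in the test_apply.py change above. A small sketch with a hypothetical test name:

from tests.testing_utils import requires_accelerate


@requires_accelerate()
def test_something_with_offloading():
    # only runs when accelerate is importable; otherwise pytest skips it
    import accelerate  # noqa: F401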