From b0668e2a4478b53f6134d10b65709493ca977947 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 30 Apr 2024 14:55:29 -0400 Subject: [PATCH] Update target match conditions; make public (#44) * update condition; make function public * style * default update --- .../quantization/lifecycle/apply.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index ab66cfe84..4c601d076 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -36,6 +36,7 @@ "load_pretrained_quantization", "apply_quantization_config", "apply_quantization_status", + "find_first_name_or_class_match", ] from compressed_tensors.quantization.utils.helpers import is_module_quantized @@ -99,9 +100,9 @@ def apply_quantization_config(model: Module, config: QuantizationConfig): # mark appropriate layers for quantization by setting their quantization schemes for name, submodule in iter_named_leaf_modules(model): - if _find_first_name_or_class_match(name, submodule, config.ignore): + if find_first_name_or_class_match(name, submodule, config.ignore): continue # layer matches ignore list, continue - target = _find_first_name_or_class_match(name, submodule, target_to_scheme) + target = find_first_name_or_class_match(name, submodule, target_to_scheme) if target is not None: # target matched - add layer and scheme to target list submodule.quantization_scheme = target_to_scheme[target] @@ -125,27 +126,31 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply(freeze_module_quantization) -def _find_first_name_or_class_match( - name: str, - module: Module, - targets: Iterable[str], +def find_first_name_or_class_match( + name: str, module: Module, targets: Iterable[str], check_contains: bool = False ) -> Optional[str]: # first element of targets that matches the given name # if no name matches returns first target that matches the class name # returns None otherwise return _find_first_match(name, targets) or _find_first_match( - module.__class__.__name__, targets + module.__class__.__name__, targets, check_contains ) -def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]: +def _find_first_match( + value: str, targets: Iterable[str], check_contains: bool = False +) -> Optional[str]: # returns first element of target that matches value either - # exactly or as a regex after 're:' + # exactly or as a regex after 're:'. if check_contains is set to True, + # additionally checks if the target string is contained with value. for target in targets: if target.startswith("re:"): pattern = target[3:] if re.match(pattern, value): return target + elif check_contains: + if target.lower() in value.lower(): + return target elif target == value: return target return None