Skip to content

Commit

Permalink
Update target match conditions; make public (vllm-project#44)
Browse files Browse the repository at this point in the history
* update condition; make function public

* style

* default update
  • Loading branch information
dsikka authored Apr 30, 2024
1 parent 0a981af commit b0668e2
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions src/compressed_tensors/quantization/lifecycle/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"load_pretrained_quantization",
"apply_quantization_config",
"apply_quantization_status",
"find_first_name_or_class_match",
]

from compressed_tensors.quantization.utils.helpers import is_module_quantized
Expand Down Expand Up @@ -99,9 +100,9 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):

# mark appropriate layers for quantization by setting their quantization schemes
for name, submodule in iter_named_leaf_modules(model):
if _find_first_name_or_class_match(name, submodule, config.ignore):
if find_first_name_or_class_match(name, submodule, config.ignore):
continue # layer matches ignore list, continue
target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
target = find_first_name_or_class_match(name, submodule, target_to_scheme)
if target is not None:
# target matched - add layer and scheme to target list
submodule.quantization_scheme = target_to_scheme[target]
Expand All @@ -125,27 +126,31 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
model.apply(freeze_module_quantization)


def _find_first_name_or_class_match(
name: str,
module: Module,
targets: Iterable[str],
def find_first_name_or_class_match(
name: str, module: Module, targets: Iterable[str], check_contains: bool = False
) -> Optional[str]:
# first element of targets that matches the given name
# if no name matches returns first target that matches the class name
# returns None otherwise
return _find_first_match(name, targets) or _find_first_match(
module.__class__.__name__, targets
module.__class__.__name__, targets, check_contains
)


def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
def _find_first_match(
value: str, targets: Iterable[str], check_contains: bool = False
) -> Optional[str]:
# returns first element of target that matches value either
# exactly or as a regex after 're:'
# exactly or as a regex after 're:'. if check_contains is set to True,
# additionally checks if the target string is contained with value.
for target in targets:
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return target
elif check_contains:
if target.lower() in value.lower():
return target
elif target == value:
return target
return None
Expand Down

0 comments on commit b0668e2

Please sign in to comment.