diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py
index d451a63f0..434f6f2d8 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/base.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -10,6 +10,7 @@
     preset_name_to_scheme,
     set_module_for_calibration,
 )
+from compressed_tensors.quantization.observers.helpers import get_observer_token_count
 from loguru import logger
 from pydantic import Field
 from torch.nn import Module
@@ -76,6 +77,9 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         if self.calculate_start() == -1:  # one-shot
             module.apply(set_module_for_calibration)
             self._calibrate_if_possible(module)
+            self._check_token_distribution(
+                module, threshold=kwargs.get("min_tokens_per_module")
+            )
             module.apply(freeze_module_quantization)

         return True
@@ -201,3 +205,39 @@ def _calibrate(self, module: Module):

         if module_training:
             module.train()
+
+    def _check_token_distribution(
+        self, model: Module, threshold: Optional[float] = None
+    ):
+        """
+        A helper function that warns when a module has seen
+        fewer than threshold % of all the tokens throughout
+        the calibration process.
+        Checks are only triggered if threshold is not None.
+        :param model: the model to validate
+        :param threshold: the minimum percentage of tokens
+            (out of all the calibration tokens) a module should
+            receive during calibration
+        """
+        if threshold is None:
+            logger.debug("Skipping token distribution check. threshold is None.")
+            return
+
+        all_tokens = self.calibration_dataloader_.dataset["input_ids"]
+        total_token_count = sum(len(sample) for sample in all_tokens)
+        counter = get_observer_token_count(model)
+        for module_name, token_count in counter.items():
+            if token_count is None:
+                # the module has not been observed,
+                # or its token_count is not being recorded
+                # by the observer (refer to the observer's
+                # implementation in the source code)
+                continue
+            if token_count / total_token_count < threshold:
+                logger.warning(
+                    f"The module_name: {module_name} "
+                    f"received less than {int(threshold * 100)}% "
+                    "of the calibration tokens "
+                    f"({token_count}/{total_token_count} tokens). "
+                    "This may harm the quantization quality."
+                )
diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py
index b6e7fc555..6bd362215 100644
--- a/src/llmcompressor/transformers/finetune/data/data_args.py
+++ b/src/llmcompressor/transformers/finetune/data/data_args.py
@@ -153,3 +153,14 @@ class DataTrainingArguments(CustomDataTrainingArguments):
             ),
         },
     )
+    min_tokens_per_module: Optional[float] = field(
+        default=0.2,
+        metadata={
+            "help": (
+                "The minimum percentage of tokens (out of the total number) "
+                "that the module should 'receive' throughout the forward "
+                "pass of the calibration. If a module receives fewer tokens, "
+                "a warning will be logged."
+            ),
+        },
+    )
diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py
index abfc82614..4bb442ded 100644
--- a/src/llmcompressor/transformers/finetune/session_mixin.py
+++ b/src/llmcompressor/transformers/finetune/session_mixin.py
@@ -107,6 +107,8 @@ def __init__(
         if self.is_fsdp_enabled:
             self._prepare_model_for_fsdp()

+        self.min_tokens_per_module = data_args.min_tokens_per_module
+
     def initialize_session(
         self,
         epoch: float,
@@ -402,6 +404,7 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
             start=-1,
             copy_data=False,
             accelerator=self.accelerator,
+            min_tokens_per_module=self.min_tokens_per_module,
        )

         # log model sparsity
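
For context, below is a minimal, hypothetical sketch of how the new check would be exercised from the user-facing one-shot flow. The recipe, model name, dataset alias, and sample counts are illustrative only; the key assumption is that `min_tokens_per_module`, now a `DataTrainingArguments` field, is forwarded like any other dataset keyword argument to `oneshot`, stored on the session mixin, and ultimately reaches `QuantizationModifier.on_initialize` via `kwargs`, where `_check_token_distribution` compares each module's observed token count against the threshold.

# Hypothetical usage sketch, not the canonical repo example.
from llmcompressor.transformers import oneshot

# Illustrative recipe: int8 weight quantization of Linear layers, skipping lm_head.
recipe = """
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            ignore: ["lm_head"]
            config_groups:
                group_0:
                    weights: {num_bits: 8, type: "int", symmetric: true, strategy: "tensor"}
                    targets: ["Linear"]
"""

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    dataset="open_platypus",
    recipe=recipe,
    num_calibration_samples=64,
    max_seq_length=512,
    # Introduced by this diff (assumed to be routed through DataTrainingArguments):
    # any module that sees fewer than 30% of the calibration tokens triggers a
    # logger.warning during the one-shot calibration pass.
    min_tokens_per_module=0.3,
    output_dir="./tinyllama-int8-token-check",
)

With the default of 0.2 in `DataTrainingArguments`, the check runs even when the caller does not set the argument; passing `min_tokens_per_module=None` would skip it entirely, since `_check_token_distribution` returns early when the threshold is `None`.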