Merge pull request #20 from vllm-project/dbogunowicz-patch-1
[MOE Quantization] Warn against "undercalibrated" modules
dbogunowicz authored Jul 11, 2024
2 parents 2eead1b + bc20726 commit 7d9c643
Showing 3 changed files with 54 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -10,6 +10,7 @@
preset_name_to_scheme,
set_module_for_calibration,
)
from compressed_tensors.quantization.observers.helpers import get_observer_token_count
from loguru import logger
from pydantic import Field
from torch.nn import Module
@@ -76,6 +77,9 @@ def on_initialize(self, state: State, **kwargs) -> bool:
if self.calculate_start() == -1: # one-shot
module.apply(set_module_for_calibration)
self._calibrate_if_possible(module)
self._check_token_distribution(
module, threshold=kwargs.get("min_tokens_per_module")
)
module.apply(freeze_module_quantization)

return True
@@ -201,3 +205,39 @@ def _calibrate(self, module: Module):

if module_training:
module.train()

def _check_token_distribution(
self, model: Module, threshold: Optional[float] = None
):
"""
A helper function that warns when a module has seen
fewer than threshold % of all the tokens throughout
the calibration process.
Checks are only triggered if threshold is not None.
:param model: the model to validate
:param threshold: the minimum percentage of tokens
(out of all the tokens in a batch) a module should
receive during calibration
"""
if threshold is None:
logger.debug("Skipping token distribution check. threshold is None.")
return

all_tokens = self.calibration_dataloader_.dataset["input_ids"]
total_token_count = sum(len(sample) for sample in all_tokens)
counter = get_observer_token_count(model)
for module_name, token_count in counter.items():
if token_count is None:
# the module has not been observed
# or its token_count is not being recorded
# by the observer (refer to the observer's
# implementation in the source code)
continue
if token_count / total_token_count < threshold:
logger.warning(
f"The module_name: {module_name} "
f"received less than {int(threshold * 100)}% "
"of calibration batch tokens "
f"({token_count}/{total_token_count} tokens). "
"This could result may harm the quantization quality."
)
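
For intuition, here is a minimal sketch of the check in isolation. It mirrors the logic above but replaces `get_observer_token_count` with a plain dict; the module names and token counts are hypothetical, not part of this commit:

from typing import Dict, Optional

def check_token_distribution(
    counter: Dict[str, Optional[int]],
    total_token_count: int,
    threshold: Optional[float] = None,
) -> None:
    # Same control flow as _check_token_distribution above: skip the
    # check when threshold is None, skip unobserved modules, and warn
    # on any module whose token share falls below the threshold.
    if threshold is None:
        return
    for module_name, token_count in counter.items():
        if token_count is None:
            continue
        if token_count / total_token_count < threshold:
            print(
                f"{module_name}: {token_count}/{total_token_count} tokens "
                f"(< {int(threshold * 100)}%)"
            )

# Hypothetical MoE counts: one expert saw only 5% of 1000 calibration tokens.
check_token_distribution(
    counter={"experts.0": 400, "experts.3": 50, "router": None},
    total_token_count=1000,
    threshold=0.2,
)
# prints: experts.3: 50/1000 tokens (< 20%)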
11 changes: 11 additions & 0 deletions src/llmcompressor/transformers/finetune/data/data_args.py
@@ -153,3 +153,14 @@ class DataTrainingArguments(CustomDataTrainingArguments):
),
},
)
min_tokens_per_module: Optional[float] = field(
default=0.2,
metadata={
"help": (
"The minimum percentage of tokens (out of the total number) "
"that the module should 'receive' throughout the forward "
"pass of the calibration. If a module receives fewer tokens, "
"a warning will be logged."
),
},
)
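
As a rough usage sketch (assuming the remaining `DataTrainingArguments` fields keep their defaults; the value 0.1 is illustrative):

from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments

# Warn when a module sees fewer than 10% of the calibration tokens;
# passing min_tokens_per_module=None disables the check entirely.
data_args = DataTrainingArguments(min_tokens_per_module=0.1)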
3 changes: 3 additions & 0 deletions src/llmcompressor/transformers/finetune/session_mixin.py
@@ -107,6 +107,8 @@ def __init__(
if self.is_fsdp_enabled:
self._prepare_model_for_fsdp()

self.min_tokens_per_module = data_args.min_tokens_per_module

def initialize_session(
self,
epoch: float,
@@ -402,6 +404,7 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
start=-1,
copy_data=False,
accelerator=self.accelerator,
min_tokens_per_module=self.min_tokens_per_module,
)

# log model sparsity
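
Taken together, the three files wire the threshold end to end: the default of 0.2 defined in data_args is stored on the trainer in the mixin's `__init__`, forwarded by `one_shot` as a kwarg when the session starts, and finally read in `QuantizationModifier.on_initialize` via `kwargs.get("min_tokens_per_module")`, where it becomes the `threshold` for `_check_token_distribution`.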