From 8ea96175269f79b6646e49202bb3aea8c76b787c Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Wed, 18 Dec 2024 14:42:52 -0500
Subject: [PATCH 1/2] Fix SmoothQuant offload bug (#978)

* fix offload

Signed-off-by: Dipika

* fix smoothquant offload bug

* remove logtime

---------

Signed-off-by: Dipika
---
 src/llmcompressor/modifiers/smoothquant/base.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py
index f4117e31d..9381348b1 100644
--- a/src/llmcompressor/modifiers/smoothquant/base.py
+++ b/src/llmcompressor/modifiers/smoothquant/base.py
@@ -2,6 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
+from compressed_tensors.utils.offload import is_module_offloaded
 from loguru import logger
 from torch.nn import Module
 
@@ -282,6 +283,10 @@ def _apply_smoothing(self, model: Module):
 
         @torch.no_grad()
         def smooth(module):
+            offloaded = is_module_offloaded(module)
+            if offloaded:
+                module._hf_hook.pre_forward(module)
+
             if module in balance_layers:
                 module.weight.mul_(scales.view(1, -1))
             elif module == smooth_layer:
@@ -292,6 +297,9 @@ def smooth(module):
                 if hasattr(module, "bias") and module.bias is not None:
                     module.bias.div_(scales)
 
+            if offloaded:
+                module._hf_hook.post_forward(module, None)
+
         parent = get_fsdp_parent(mapping.smooth_name, model)
         if parent is not None:
             parent.apply(smooth)
@@ -318,8 +326,16 @@ def _calculate_smoothing_scales(
         # get the channel-wise dynamic range for each layer to be balanced
         weight_scales = []
         for layer in balance_layers:
+            offloaded = is_module_offloaded(layer)
+            if offloaded:
+                layer._hf_hook.pre_forward(layer)
+
             scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
             weight_scales.append(scale)
+
+            if offloaded:
+                layer._hf_hook.post_forward(layer, None)
+
         weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]
 
         # calculate the amount of smoothing to apply
@@ -329,4 +345,5 @@ def _calculate_smoothing_scales(
             1 - self.smoothing_strength
         )
         scales = torch.where(weight_scales > 0.0, scales, activation_scales)
+
         return scales

From 7366a2d75c2b1aa61b9a3097fa7e1365de3382f6 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Wed, 18 Dec 2024 21:31:54 -0500
Subject: [PATCH 2/2] Add LM Eval Configs (#980)

---
 .../vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml   |  4 ++--
 .../vLLM/lm_eval_configs/fp8_static_per_tensor.yaml   | 10 ++++++++++
 .../lm_eval_configs/int8_w8a8_dynamic_per_token.yaml  | 11 +++++++----
 .../vLLM/lm_eval_configs/w4a16_actorder_weight.yaml   | 11 +++++++++++
 .../e2e/vLLM/lm_eval_configs/w4a16_grouped_quant.yaml | 11 +++++++++++
 .../recipe_int8_channel_weight_dynamic_per_token.yaml | 11 +++++++++++
 ...ipe_int8_channel_weight_static_per_tensor_act.yaml |  2 +-
 tests/e2e/vLLM/test_lmeval.py                         |  4 ++--
 8 files changed, 55 insertions(+), 9 deletions(-)
 create mode 100644 tests/e2e/vLLM/lm_eval_configs/fp8_static_per_tensor.yaml
 create mode 100644 tests/e2e/vLLM/lm_eval_configs/w4a16_actorder_weight.yaml
 create mode 100644 tests/e2e/vLLM/lm_eval_configs/w4a16_grouped_quant.yaml
 create mode 100644 tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml

diff --git a/tests/e2e/vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml b/tests/e2e/vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml
index 461353770..fc610bae9 100644
--- a/tests/e2e/vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml
+++ b/tests/e2e/vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml
@@ -4,5 +4,5 @@ scheme: FP8_DYNAMIC
 num_fewshot: 5
 limit: 1000
 task: "gsm8k"
-exact_match,flexible-extract: 0.753
-exact_match,strict-match: 0.753
+exact_match,flexible-extract: 0.75
+exact_match,strict-match: 0.75
diff --git a/tests/e2e/vLLM/lm_eval_configs/fp8_static_per_tensor.yaml b/tests/e2e/vLLM/lm_eval_configs/fp8_static_per_tensor.yaml
new file mode 100644
index 000000000..0b6d42a46
--- /dev/null
+++ b/tests/e2e/vLLM/lm_eval_configs/fp8_static_per_tensor.yaml
@@ -0,0 +1,10 @@
+cadence: "weekly"
+model: meta-llama/Meta-Llama-3-8B-Instruct
+scheme: FP8
+num_fewshot: 5
+limit: 1000
+task: "gsm8k"
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+exact_match,flexible-extract: 0.75
+exact_match,strict-match: 0.75
diff --git a/tests/e2e/vLLM/lm_eval_configs/int8_w8a8_dynamic_per_token.yaml b/tests/e2e/vLLM/lm_eval_configs/int8_w8a8_dynamic_per_token.yaml
index b16f5575a..446ca1e7f 100644
--- a/tests/e2e/vLLM/lm_eval_configs/int8_w8a8_dynamic_per_token.yaml
+++ b/tests/e2e/vLLM/lm_eval_configs/int8_w8a8_dynamic_per_token.yaml
@@ -1,8 +1,11 @@
 cadence: "weekly"
 model: meta-llama/Meta-Llama-3-8B-Instruct
-scheme: INT8
+scheme: INT8_dyn_per_token
+recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml
 num_fewshot: 5
-limit: 250
+limit: 1000
 task: "gsm8k"
-exact_match,flexible-extract: 0.728
-exact_match,strict-match: 0.728
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+exact_match,flexible-extract: 0.77
+exact_match,strict-match: 0.76
diff --git a/tests/e2e/vLLM/lm_eval_configs/w4a16_actorder_weight.yaml b/tests/e2e/vLLM/lm_eval_configs/w4a16_actorder_weight.yaml
new file mode 100644
index 000000000..ca82bb44f
--- /dev/null
+++ b/tests/e2e/vLLM/lm_eval_configs/w4a16_actorder_weight.yaml
@@ -0,0 +1,11 @@
+cadence: "weekly"
+model: meta-llama/Meta-Llama-3-8B-Instruct
+recipe: tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml
+num_fewshot: 5
+limit: 1000
+task: "gsm8k"
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+exact_match,flexible-extract: 0.72
+exact_match,strict-match: 0.72
+scheme: W4A16_actorder_group
\ No newline at end of file
diff --git a/tests/e2e/vLLM/lm_eval_configs/w4a16_grouped_quant.yaml b/tests/e2e/vLLM/lm_eval_configs/w4a16_grouped_quant.yaml
new file mode 100644
index 000000000..a4c7b6244
--- /dev/null
+++ b/tests/e2e/vLLM/lm_eval_configs/w4a16_grouped_quant.yaml
@@ -0,0 +1,11 @@
+cadence: "weekly"
+model: meta-llama/Meta-Llama-3-8B-Instruct
+num_fewshot: 5
+limit: 1000
+task: "gsm8k"
+exact_match,flexible-extract: 0.72
+exact_match,strict-match: 0.72
+scheme: W4A16
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+quant_type: "GPTQ"
\ No newline at end of file
diff --git a/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml
new file mode 100644
index 000000000..367437e5a
--- /dev/null
+++ b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml
@@ -0,0 +1,11 @@
+quant_stage:
+  quant_modifiers:
+    SmoothQuantModifier:
+      smoothing_strength: 0.8
+    GPTQModifier:
+      ignore: [lm_head]
+      config_groups:
+        group_0:
+          weights: {num_bits: 8, type: int, symmetric: true, strategy: channel}
+          input_activations: {num_bits: 8, type: int, symmetric: true, strategy: token, dynamic: true}
+          targets: [Linear]
diff --git a/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml
index 2c0094f88..9703872bc 100644
--- a/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml
+++ b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml
@@ -2,7 +2,7 @@ quant_stage:
   quant_modifiers:
     SmoothQuantModifier:
       smoothing_strength: 0.8
-    QuantizationModifier:
+    GPTQModifier:
       ignore: [lm_head]
       config_groups:
         group_0:
diff --git a/tests/e2e/vLLM/test_lmeval.py b/tests/e2e/vLLM/test_lmeval.py
index f77bda983..4e11123a5 100644
--- a/tests/e2e/vLLM/test_lmeval.py
+++ b/tests/e2e/vLLM/test_lmeval.py
@@ -68,7 +68,7 @@ def set_up(self):
         logger.info(self.scheme)
 
         self.device = "cuda:0"
-        self.num_calibration_samples = 256
+        self.num_calibration_samples = 512
         self.max_seq_length = 2048
 
     def test_lm_eval(self):
@@ -104,7 +104,7 @@ def test_lm_eval(self):
 
         logger.info("================= Running LM Eval ======================")
 
-        model_args = f"pretrained={self.save_dir}"
+        model_args = f"pretrained={self.save_dir},add_bos_token=True"
         results = lm_eval.simple_evaluate(
             model="hf",
             model_args=model_args,
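Reviewer note (illustrative, not part of either patch): both hunks in PATCH 1/2 apply the same pattern — when a module has been offloaded by accelerate, its parameters are materialized with `_hf_hook.pre_forward(...)` before any in-place edit and handed back with `_hf_hook.post_forward(...)` afterwards, so the scaling actually reaches the offloaded copy of the weights. A minimal, self-contained sketch of that pattern is below; the helper name `apply_channel_scale` and the toy `Linear` layer are assumptions made for illustration, not code from the repository.

    # Minimal sketch of the offload-aware in-place update used by the patch.
    # Assumption: `apply_channel_scale` is a hypothetical helper, not repo code.
    import torch
    from compressed_tensors.utils.offload import is_module_offloaded
    from torch.nn import Linear, Module


    @torch.no_grad()
    def apply_channel_scale(module: Module, scales: torch.Tensor) -> None:
        offloaded = is_module_offloaded(module)
        if offloaded:
            # materialize parameters on the execution device before editing them
            module._hf_hook.pre_forward(module)

        # same in-place update the patched `smooth` closure applies to balance
        # layers: scale each input channel of the weight matrix
        module.weight.mul_(scales.view(1, -1))

        if offloaded:
            # return the (now modified) parameters to the offload hook
            module._hf_hook.post_forward(module, None)


    # usage: a regular, non-offloaded layer simply skips both hook calls
    layer = Linear(in_features=4, out_features=2, bias=False)
    apply_channel_scale(layer, torch.ones(4))

Without the surrounding hook calls, the in-place `mul_`/`div_` edits would not reach the weights that accelerate keeps on the offload device, which appears to be the failure mode the first commit addresses.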