From 3662e0ecfbcf8faa0e29f9ac613c3372d7a9e916 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 19 Jun 2024 09:52:44 -0600 Subject: [PATCH] Revert "Support calibrating kv cache scales (#17)" This reverts commit 0d40b99f0c3441fc45f756a1d84dcec9ce6cbf83. --- auto_fp8/config.py | 9 +- auto_fp8/modeling.py | 49 +++------ auto_fp8/quantize.py | 156 +++++++++-------------------- examples/example_static_kvcache.py | 25 ----- tests/test_auto_fp8.py | 79 ++------------- 5 files changed, 69 insertions(+), 249 deletions(-) delete mode 100644 examples/example_static_kvcache.py diff --git a/auto_fp8/config.py b/auto_fp8/config.py index 24c6200..7f8dd95 100644 --- a/auto_fp8/config.py +++ b/auto_fp8/config.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List class BaseQuantizeConfig: @@ -17,17 +17,13 @@ class BaseQuantizeConfig: regex style matching i.e. re.search(), for each Linear layer. By default, "re:.*lm_head" is included to ignore the embedding Linear layer usually at the end of decoder LLMs - kv_cache_quant_targets: Tuple of Linear module names to target for - calibration of the output scales for KV cache quantization. - Usually, these should be `("k_proj", "v_proj")`. """ def __init__( self, quant_method: str = "fp8", activation_scheme: str = "static", - ignore_patterns: List[str] = ["re:.*lm_head"], - kv_cache_quant_targets: Optional[Tuple[str]] = None, + ignore_patterns: List[str] = [], ): if quant_method != "fp8": raise ValueError("Only FP8 quantization is supported.") @@ -38,5 +34,4 @@ def __init__( self.quant_method = quant_method self.activation_scheme = activation_scheme self.ignore_patterns = ignore_patterns - self.kv_cache_quant_targets = kv_cache_quant_targets self.ignored_layers = [] diff --git a/auto_fp8/modeling.py b/auto_fp8/modeling.py index 04a9e71..340a598 100644 --- a/auto_fp8/modeling.py +++ b/auto_fp8/modeling.py @@ -1,5 +1,5 @@ import re -from typing import List, Optional, Tuple +from typing import List import torch from transformers import AutoModelForCausalLM @@ -27,16 +27,6 @@ def __init__( self.model, quantize_config.ignore_patterns ) - if quantize_config.kv_cache_quant_targets: - kv_cache_quant_layers = get_kv_cache_quant_layers( - self.model, quantize_config.kv_cache_quant_targets - ) - if len(kv_cache_quant_layers) == 0: - raise ValueError( - f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument." 
- ) - quantize_config.kv_cache_quant_layers = kv_cache_quant_layers - self.quantize_config = quantize_config @classmethod @@ -107,28 +97,26 @@ def skip(*args, **kwargs): return cls(model, quantize_config) - def quantize(self, calibration_tokens: Optional[torch.Tensor] = None): + def quantize(self, calibration_tokens): + def _prepare_calibration_data(calibration_tokens): + if hasattr(calibration_tokens, "input_ids"): + return calibration_tokens.input_ids + return calibration_tokens # Always quantize the weights as they do not require calibration data quantize_weights(self.model, self.quantize_config) if self.quantize_config.activation_scheme == "static": - assert ( - calibration_tokens is not None - ), "Calibration tokens required for activation quantization" - - - def _prepare_calibration_data(calibration_tokens): - if hasattr(calibration_tokens, "input_ids"): - return calibration_tokens.input_ids - return calibration_tokens - quantize_activations( self.model, self.quantize_config, _prepare_calibration_data(calibration_tokens), ) + # import copy + # for layer in self.model.model.layers: + # layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.input_scale) + def save_quantized(self, save_dir): save_quantized_model( self.model, @@ -140,6 +128,9 @@ def save_quantized(self, save_dir): def get_layers_to_ignore(model, ignore_patterns) -> List[str]: ignored_layers = set() + # TODO: don't always ignore lm_head + ignore_patterns.append("re:.*lm_head") + for name, linear in model.named_modules(): if not isinstance(linear, torch.nn.Linear): continue @@ -157,17 +148,3 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]: ignored_layers.add(name) return list(ignored_layers) - - -def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]: - kv_cache_quant_layers = [] - - for name, linear in model.named_modules(): - if not isinstance(linear, torch.nn.Linear): - continue - - for output_quant_target in kv_cache_quant_targets: - if name.endswith(output_quant_target): - kv_cache_quant_layers.append(name) - - return kv_cache_quant_layers diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py index 38a4de6..4c1b580 100644 --- a/auto_fp8/quantize.py +++ b/auto_fp8/quantize.py @@ -1,6 +1,6 @@ import gc import re -from typing import Optional, Tuple +from typing import List, Tuple import copy import torch @@ -61,22 +61,14 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]: return qweight, scale -def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) - return qweight.to(torch.float8_e4m3fn) - - def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): if A.numel() == 0: # Deal with empty tensors (triggeted by empty MoE experts) return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device) - - # TODO: Disable native fp8 gemm for now, always just dequantize - # native_fp8_support = ( - # torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) - # ) - native_fp8_support = False + + native_fp8_support = ( + torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) + ) if native_fp8_support: need_reshape = A.dim() == 3 if need_reshape: @@ -106,24 +98,25 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): return output -# Class responsible for quantizing weights -class FP8DynamicLinear(torch.nn.Module): +class 
FP8StaticLinearQuantizer(torch.nn.Module): def __init__( - self, - weight: torch.Tensor, - weight_scale: torch.Tensor, - bias: torch.nn.Parameter, + self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor ): super().__init__() - self.weight = torch.nn.Parameter(weight, requires_grad=False) + self.weight = torch.nn.Parameter(qweight, requires_grad=False) self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + self.input_scale = None self.bias = bias def forward(self, x): - qinput, x_scale = per_tensor_quantize(x) + qinput, x_input_scale = per_tensor_quantize(x) + if self.input_scale is None: + self.input_scale = torch.nn.Parameter(x_input_scale) + elif x_input_scale > self.input_scale: + self.input_scale = torch.nn.Parameter(x_input_scale) output = fp8_gemm( A=qinput, - A_scale=x_scale, + A_scale=self.input_scale, B=self.weight, B_scale=self.weight_scale, bias=self.bias, @@ -132,29 +125,29 @@ def forward(self, x): return output -# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer -class FP8StaticLinearQuantizer(torch.nn.Module): +class FP8StaticLinear(torch.nn.Module): def __init__( self, - weight: torch.Tensor, + qweight: torch.Tensor, weight_scale: torch.Tensor, - bias: torch.nn.Parameter, - quantize_output: bool = False, + bias: torch.Tensor, + input_scale: float = 1.0, ): super().__init__() - self.weight = torch.nn.Parameter(weight, requires_grad=False) + self.weight = torch.nn.Parameter(qweight, requires_grad=False) self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False) self.bias = bias - self.input_scale = None - self.output_scale = None - self.quantize_output = quantize_output + + def per_tensor_quantize( + self, tensor: torch.Tensor, inv_scale: float + ) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) + return qweight.to(torch.float8_e4m3fn) def forward(self, x): - qinput, x_input_scale = per_tensor_quantize(x) - if self.input_scale is None: - self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False) - elif x_input_scale > self.input_scale: - self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False) + qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale) output = fp8_gemm( A=qinput, A_scale=self.input_scale, @@ -163,51 +156,26 @@ def forward(self, x): bias=self.bias, out_dtype=x.dtype, ) - - # Optionally, quantize output and record scale - if self.quantize_output: - qoutput, output_scale = per_tensor_quantize(output) - if self.output_scale is None: - self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False) - elif output_scale > self.output_scale: - self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False) - output = qoutput.to(output.dtype) * output_scale - return output -# Module responsible for representing the final checkpoint representation -class FP8StaticLinear(torch.nn.Module): - def __init__( - self, - weight: torch.nn.Parameter, - weight_scale: torch.nn.Parameter, - bias: torch.nn.Parameter, - input_scale: torch.nn.Parameter, - output_scale: Optional[torch.nn.Parameter] = None, - ): +class FP8DynamicLinear(torch.nn.Module): + def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor): super().__init__() - self.weight = weight - self.weight_scale = weight_scale + self.weight = 
torch.nn.Parameter(qweight, requires_grad=False) + self.weight_scale = torch.nn.Parameter(scale, requires_grad=False) self.bias = bias - self.input_scale = input_scale - self.output_scale = output_scale def forward(self, x): - qinput = static_per_tensor_quantize(x, self.input_scale) + qinput, x_scale = per_tensor_quantize(x) output = fp8_gemm( A=qinput, - A_scale=self.input_scale, + A_scale=x_scale, B=self.weight, B_scale=self.weight_scale, bias=self.bias, out_dtype=x.dtype, ) - - if self.output_scale: - qoutput = static_per_tensor_quantize(output, self.output_scale) - output = qoutput.to(output.dtype) * self.output_scale - return output @@ -226,6 +194,7 @@ def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn. def quantize_weights( model: AutoModelForCausalLM, quantize_config: BaseQuantizeConfig, + ignored_layers: List[str] = [], ): named_modules = list(model.named_modules()) for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"): @@ -234,11 +203,9 @@ def quantize_weights( or name in quantize_config.ignored_layers ): continue - quant_weight, weight_scale = per_tensor_quantize(linear.weight) + quant_weight, quant_scale = per_tensor_quantize(linear.weight) bias = copy.deepcopy(linear.bias) if linear.bias is not None else None - quant_linear = FP8DynamicLinear( - weight=quant_weight, weight_scale=weight_scale, bias=bias - ) + quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias) replace_module(model, name, quant_linear) del linear.weight del linear.bias @@ -250,6 +217,7 @@ def quantize_activations( model: AutoModelForCausalLM, quantize_config: BaseQuantizeConfig, calibration_tokens, + ignored_layers: List[str] = [], ): # Replace weight quantizer with a dynamic activation quantizer observer for name, dynamic_quant_linear in model.named_modules(): @@ -259,13 +227,9 @@ def quantize_activations( ): continue quantizer = FP8StaticLinearQuantizer( - weight=dynamic_quant_linear.weight, - weight_scale=dynamic_quant_linear.weight_scale, - bias=dynamic_quant_linear.bias, - quantize_output=( - hasattr(quantize_config, "kv_cache_quant_layers") - and name in quantize_config.kv_cache_quant_layers - ), + dynamic_quant_linear.weight, + dynamic_quant_linear.weight_scale, + dynamic_quant_linear.bias, ) replace_module(model, name, quantizer) del dynamic_quant_linear @@ -287,45 +251,21 @@ def quantize_activations( ): continue static_proj = FP8StaticLinear( - weight=quantizer.weight, - weight_scale=quantizer.weight_scale, - bias=quantizer.bias, - input_scale=quantizer.input_scale, - output_scale=quantizer.output_scale, + quantizer.weight, + quantizer.weight_scale, + quantizer.bias, + quantizer.input_scale, ) replace_module(model, name, static_proj) del quantizer cleanup_memory() - # Post-process step for kv cache scales to take the k/v module - # `output_scale` parameters, take the max of them, and store them in - # the parent attention module as `kv_scale` - # NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block - if hasattr(quantize_config, "kv_cache_quant_layers"): - # Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...] - # so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...] 
- kv_proj_pairs = zip(*[iter(quantize_config.kv_cache_quant_layers)]*2) - for k_proj_name, v_proj_name in kv_proj_pairs: - parent_module_name = ".".join(k_proj_name.split(".")[:-1]) - assert parent_module_name == ".".join(v_proj_name.split(".")[:-1]) - parent_module = dict(model.named_modules())[parent_module_name] - - k_proj = dict(model.named_modules())[k_proj_name] - v_proj = dict(model.named_modules())[v_proj_name] - - kv_scale = max(k_proj.output_scale, v_proj.output_scale) - parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False) - - # Remove output_scale from k_proj and v_proj - k_proj.output_scale = None - v_proj.output_scale = None - cleanup_memory() - def save_quantized_model( model: AutoModelForCausalLM, quant_config: BaseQuantizeConfig, save_dir: str, + ignored_layers: List[str] = [], ): print(model) print(f"Saving the model to {save_dir}") @@ -336,8 +276,6 @@ def save_quantized_model( "ignored_layers": quant_config.ignored_layers, } } - if hasattr(quant_config, "kv_cache_quant_layers"): - static_q_dict["quantization_config"]["kv_cache_scheme"] = "static" model.config.update(static_q_dict) model.save_pretrained(save_dir) tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) diff --git a/examples/example_static_kvcache.py b/examples/example_static_kvcache.py deleted file mode 100644 index 118bad5..0000000 --- a/examples/example_static_kvcache.py +++ /dev/null @@ -1,25 +0,0 @@ -from datasets import load_dataset -from transformers import AutoTokenizer - -from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig - -pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" -quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8-KV" - -tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) -tokenizer.pad_token = tokenizer.eos_token - -ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512)) -examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds] -examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda") - -quantize_config = BaseQuantizeConfig( - quant_method="fp8", - activation_scheme="static", - ignore_patterns=["re:.*lm_head"], - kv_cache_quant_targets=("k_proj", "v_proj"), -) - -model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) -model.quantize(examples) -model.save_quantized(quantized_model_dir) diff --git a/tests/test_auto_fp8.py b/tests/test_auto_fp8.py index 6045d84..51db3c1 100644 --- a/tests/test_auto_fp8.py +++ b/tests/test_auto_fp8.py @@ -1,43 +1,14 @@ import os import shutil -import pytest -import safetensors.torch from transformers import AutoTokenizer from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig -MODELS = [ - ("facebook/opt-125m", 160), - ("Qwen/Qwen2-0.5B-Instruct", 620), -] -@pytest.mark.parametrize("model_id,target_size", MODELS) -def test_dynamic_quantization(model_id, target_size): - quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic" - - quantize_config = BaseQuantizeConfig( - quant_method="fp8", activation_scheme="dynamic" - ) - - model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) - model.model.to("cpu") - - model.quantize() - model.save_quantized(quantized_model_dir) - - # Measure checkpoint size and cleanup - model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors") - shutil.rmtree(quantized_model_dir) - - # We expect the quantized model to be a certain size - target_size = target_size * (1024 * 1024) - assert 
model_size < target_size - - -@pytest.mark.parametrize("model_id,target_size", MODELS) -def test_static_quantization(model_id, target_size): - quantized_model_dir = model_id.split("/")[-1] + "-fp8-static" +def test_quantization(): + model_id = "facebook/opt-125m" + quantized_model_dir = "opt-125m-fp8" tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) examples = ["auto-fp8 is an easy-to-use model quantization library"] @@ -45,54 +16,18 @@ def test_static_quantization(model_id, target_size): quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static") - model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) - model.model.to("cpu") - - model.quantize(examples) - model.save_quantized(quantized_model_dir) - - # Measure checkpoint size and cleanup - model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors") - shutil.rmtree(quantized_model_dir) - - # We expect the quantized model to be a certain size - target_size = target_size * (1024 * 1024) - assert model_size < target_size - -@pytest.mark.parametrize("model_id,target_size", MODELS) -def test_kv_cache_static_quantization(model_id, target_size): - quantized_model_dir = model_id.split("/")[-1] + "-fp8-static-kv" - - tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) - examples = ["auto-fp8 is an easy-to-use model quantization library"] - examples = tokenizer(examples, return_tensors="pt") - - quantize_config = BaseQuantizeConfig( - quant_method="fp8", - activation_scheme="static", - kv_cache_quant_targets=("k_proj", "v_proj"), + model = AutoFP8ForCausalLM.from_pretrained( + model_id, quantize_config=quantize_config ) - - model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) model.model.to("cpu") model.quantize(examples) model.save_quantized(quantized_model_dir) - tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors") - proj_linear_count = 0 - kv_scale_count = 0 - for name, _ in tensors.items(): - if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"): - proj_linear_count += 1 - if name.endswith("kv_scale"): - kv_scale_count += 1 - assert proj_linear_count // 2 == kv_scale_count - # Measure checkpoint size and cleanup model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors") shutil.rmtree(quantized_model_dir) - # We expect the quantized model to be a certain size - target_size = target_size * (1024 * 1024) + # We expect the model to be < 160MB + target_size = 160 * (1024 * 1024) assert model_size < target_size
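
Usage note: with the kv cache calibration path reverted, the remaining flow is
plain FP8 weight quantization plus, for activation_scheme="static", calibration
of per-tensor input scales. The sketch below is not part of the diff; it simply
mirrors the retained test in tests/test_auto_fp8.py. The model id, output
directory, and the single calibration sentence are just the values the test
happens to use, not requirements of the API.

    from transformers import AutoTokenizer
    from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

    model_id = "facebook/opt-125m"        # small model used by the test
    quantized_model_dir = "opt-125m-fp8"  # arbitrary output directory

    # Static activation quantization needs real token ids so the observer
    # modules (FP8StaticLinearQuantizer) can record per-tensor input scales
    # during the calibration forward passes.
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    examples = tokenizer(
        ["auto-fp8 is an easy-to-use model quantization library"],
        return_tensors="pt",
    )

    # After this revert, ignore_patterns defaults to [] and lm_head is always
    # appended inside get_layers_to_ignore(); kv_cache_quant_targets is gone.
    quantize_config = BaseQuantizeConfig(
        quant_method="fp8",
        activation_scheme="static",
    )

    model = AutoFP8ForCausalLM.from_pretrained(
        model_id, quantize_config=quantize_config
    )
    model.model.to("cpu")  # the test calibrates on CPU; a GPU also works

    model.quantize(examples)       # quantize weights, then calibrate activations
    model.save_quantized(quantized_model_dir)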