Revert "Support calibrating kv cache scales (#17)"
This reverts commit 0d40b99.
mgoin authored Jun 19, 2024
1 parent 2a9330c commit 3662e0e
Showing 5 changed files with 69 additions and 249 deletions.
9 changes: 2 additions & 7 deletions auto_fp8/config.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List


class BaseQuantizeConfig:
@@ -17,17 +17,13 @@ class BaseQuantizeConfig:
regex style matching i.e. re.search(), for each Linear layer.
By default, "re:.*lm_head" is included to ignore the embedding
Linear layer usually at the end of decoder LLMs
kv_cache_quant_targets: Tuple of Linear module names to target for
calibration of the output scales for KV cache quantization.
Usually, these should be `("k_proj", "v_proj")`.
"""

def __init__(
self,
quant_method: str = "fp8",
activation_scheme: str = "static",
ignore_patterns: List[str] = ["re:.*lm_head"],
kv_cache_quant_targets: Optional[Tuple[str]] = None,
ignore_patterns: List[str] = [],
):
if quant_method != "fp8":
raise ValueError("Only FP8 quantization is supported.")
@@ -38,5 +34,4 @@ def __init__(
self.quant_method = quant_method
self.activation_scheme = activation_scheme
self.ignore_patterns = ignore_patterns
self.kv_cache_quant_targets = kv_cache_quant_targets
self.ignored_layers = []
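
For context, the configuration surface being removed above looked roughly like the usage sketch below. The BaseQuantizeConfig arguments come from the removed docstring and signature; the auto_fp8 import path and the AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) call are assumptions based on the modules in this diff, and the model id is a placeholder.

# Sketch of the reverted kv-cache calibration option (not a confirmed API).
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig  # assumed import path

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignore_patterns=["re:.*lm_head"],
    # Removed by this revert: target the k/v projection outputs so their
    # scales can be calibrated and later folded into a per-layer kv_scale.
    kv_cache_quant_targets=("k_proj", "v_proj"),
)

model = AutoFP8ForCausalLM.from_pretrained("<base-model-id>", quantize_config)  # placeholder id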
49 changes: 13 additions & 36 deletions auto_fp8/modeling.py
@@ -1,5 +1,5 @@
import re
from typing import List, Optional, Tuple
from typing import List

import torch
from transformers import AutoModelForCausalLM
@@ -27,16 +27,6 @@ def __init__(
self.model, quantize_config.ignore_patterns
)

if quantize_config.kv_cache_quant_targets:
kv_cache_quant_layers = get_kv_cache_quant_layers(
self.model, quantize_config.kv_cache_quant_targets
)
if len(kv_cache_quant_layers) == 0:
raise ValueError(
f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument."
)
quantize_config.kv_cache_quant_layers = kv_cache_quant_layers

self.quantize_config = quantize_config

@classmethod
@@ -107,28 +97,26 @@ def skip(*args, **kwargs):

return cls(model, quantize_config)

def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
def quantize(self, calibration_tokens):
def _prepare_calibration_data(calibration_tokens):
if hasattr(calibration_tokens, "input_ids"):
return calibration_tokens.input_ids
return calibration_tokens

# Always quantize the weights as they do not require calibration data
quantize_weights(self.model, self.quantize_config)

if self.quantize_config.activation_scheme == "static":
assert (
calibration_tokens is not None
), "Calibration tokens required for activation quantization"


def _prepare_calibration_data(calibration_tokens):
if hasattr(calibration_tokens, "input_ids"):
return calibration_tokens.input_ids
return calibration_tokens

quantize_activations(
self.model,
self.quantize_config,
_prepare_calibration_data(calibration_tokens),
)

# import copy
# for layer in self.model.model.layers:
# layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.input_scale)

def save_quantized(self, save_dir):
save_quantized_model(
self.model,
@@ -140,6 +128,9 @@ def save_quantized(self, save_dir):
def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
ignored_layers = set()

# TODO: don't always ignore lm_head
ignore_patterns.append("re:.*lm_head")

for name, linear in model.named_modules():
if not isinstance(linear, torch.nn.Linear):
continue
@@ -157,17 +148,3 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
ignored_layers.add(name)

return list(ignored_layers)


def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
kv_cache_quant_layers = []

for name, linear in model.named_modules():
if not isinstance(linear, torch.nn.Linear):
continue

for output_quant_target in kv_cache_quant_targets:
if name.endswith(output_quant_target):
kv_cache_quant_layers.append(name)

return kv_cache_quant_layers
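
As the quantize() method above shows, weights are always quantized, while static activation quantization additionally needs calibration tokens, and _prepare_calibration_data unwraps a tokenizer output's .input_ids. Continuing the configuration sketch after config.py, a calibration call might look like the following; the calibration text, batch size, and sequence length are placeholders rather than values from this repository.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("<base-model-id>")  # placeholder id
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token  # padding needs a pad token

# Either the BatchEncoding or its .input_ids can be passed; quantize()
# unwraps .input_ids itself via _prepare_calibration_data.
calibration_tokens = tokenizer(
    ["A few representative prompts for calibration."] * 16,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

model.quantize(calibration_tokens)  # weights always; activations only for the "static" scheme
model.save_quantized("model-fp8")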
156 changes: 47 additions & 109 deletions auto_fp8/quantize.py
@@ -1,6 +1,6 @@
import gc
import re
from typing import Optional, Tuple
from typing import List, Tuple
import copy

import torch
@@ -61,22 +61,14 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
return qweight, scale


def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)


def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
if A.numel() == 0:
# Deal with empty tensors (triggered by empty MoE experts)
return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)

# TODO: Disable native fp8 gemm for now, always just dequantize
# native_fp8_support = (
# torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
# )
native_fp8_support = False

native_fp8_support = (
torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
)
if native_fp8_support:
need_reshape = A.dim() == 3
if need_reshape:
@@ -106,24 +98,25 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
return output


# Class responsible for quantizing weights
class FP8DynamicLinear(torch.nn.Module):
class FP8StaticLinearQuantizer(torch.nn.Module):
def __init__(
self,
weight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.nn.Parameter,
self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor
):
super().__init__()
self.weight = torch.nn.Parameter(weight, requires_grad=False)
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
self.input_scale = None
self.bias = bias

def forward(self, x):
qinput, x_scale = per_tensor_quantize(x)
qinput, x_input_scale = per_tensor_quantize(x)
if self.input_scale is None:
self.input_scale = torch.nn.Parameter(x_input_scale)
elif x_input_scale > self.input_scale:
self.input_scale = torch.nn.Parameter(x_input_scale)
output = fp8_gemm(
A=qinput,
A_scale=x_scale,
A_scale=self.input_scale,
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
@@ -132,29 +125,29 @@ def forward(self, x):
return output


# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer
class FP8StaticLinearQuantizer(torch.nn.Module):
class FP8StaticLinear(torch.nn.Module):
def __init__(
self,
weight: torch.Tensor,
qweight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.nn.Parameter,
quantize_output: bool = False,
bias: torch.Tensor,
input_scale: float = 1.0,
):
super().__init__()
self.weight = torch.nn.Parameter(weight, requires_grad=False)
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
self.bias = bias
self.input_scale = None
self.output_scale = None
self.quantize_output = quantize_output

def per_tensor_quantize(
self, tensor: torch.Tensor, inv_scale: float
) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)

def forward(self, x):
qinput, x_input_scale = per_tensor_quantize(x)
if self.input_scale is None:
self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
elif x_input_scale > self.input_scale:
self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
output = fp8_gemm(
A=qinput,
A_scale=self.input_scale,
@@ -163,51 +156,26 @@ def forward(self, x):
bias=self.bias,
out_dtype=x.dtype,
)

# Optionally, quantize output and record scale
if self.quantize_output:
qoutput, output_scale = per_tensor_quantize(output)
if self.output_scale is None:
self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
elif output_scale > self.output_scale:
self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
output = qoutput.to(output.dtype) * output_scale

return output


# Module responsible for representing the final checkpoint representation
class FP8StaticLinear(torch.nn.Module):
def __init__(
self,
weight: torch.nn.Parameter,
weight_scale: torch.nn.Parameter,
bias: torch.nn.Parameter,
input_scale: torch.nn.Parameter,
output_scale: Optional[torch.nn.Parameter] = None,
):
class FP8DynamicLinear(torch.nn.Module):
def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor):
super().__init__()
self.weight = weight
self.weight_scale = weight_scale
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
self.bias = bias
self.input_scale = input_scale
self.output_scale = output_scale

def forward(self, x):
qinput = static_per_tensor_quantize(x, self.input_scale)
qinput, x_scale = per_tensor_quantize(x)
output = fp8_gemm(
A=qinput,
A_scale=self.input_scale,
A_scale=x_scale,
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
out_dtype=x.dtype,
)

if self.output_scale:
qoutput = static_per_tensor_quantize(output, self.output_scale)
output = qoutput.to(output.dtype) * self.output_scale

return output


@@ -226,6 +194,7 @@ def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.
def quantize_weights(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
ignored_layers: List[str] = [],
):
named_modules = list(model.named_modules())
for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
@@ -234,11 +203,9 @@ def quantize_weights(
or name in quantize_config.ignored_layers
):
continue
quant_weight, weight_scale = per_tensor_quantize(linear.weight)
quant_weight, quant_scale = per_tensor_quantize(linear.weight)
bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
quant_linear = FP8DynamicLinear(
weight=quant_weight, weight_scale=weight_scale, bias=bias
)
quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias)
replace_module(model, name, quant_linear)
del linear.weight
del linear.bias
@@ -250,6 +217,7 @@ def quantize_activations(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
calibration_tokens,
ignored_layers: List[str] = [],
):
# Replace weight quantizer with a dynamic activation quantizer observer
for name, dynamic_quant_linear in model.named_modules():
@@ -259,13 +227,9 @@ def quantize_activations(
):
continue
quantizer = FP8StaticLinearQuantizer(
weight=dynamic_quant_linear.weight,
weight_scale=dynamic_quant_linear.weight_scale,
bias=dynamic_quant_linear.bias,
quantize_output=(
hasattr(quantize_config, "kv_cache_quant_layers")
and name in quantize_config.kv_cache_quant_layers
),
dynamic_quant_linear.weight,
dynamic_quant_linear.weight_scale,
dynamic_quant_linear.bias,
)
replace_module(model, name, quantizer)
del dynamic_quant_linear
@@ -287,45 +251,21 @@ def quantize_activations(
):
continue
static_proj = FP8StaticLinear(
weight=quantizer.weight,
weight_scale=quantizer.weight_scale,
bias=quantizer.bias,
input_scale=quantizer.input_scale,
output_scale=quantizer.output_scale,
quantizer.weight,
quantizer.weight_scale,
quantizer.bias,
quantizer.input_scale,
)
replace_module(model, name, static_proj)
del quantizer
cleanup_memory()

# Post-process step for kv cache scales to take the k/v module
# `output_scale` parameters, take the max of them, and store them in
# the parent attention module as `kv_scale`
# NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block
if hasattr(quantize_config, "kv_cache_quant_layers"):
# Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...]
# so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
kv_proj_pairs = zip(*[iter(quantize_config.kv_cache_quant_layers)]*2)
for k_proj_name, v_proj_name in kv_proj_pairs:
parent_module_name = ".".join(k_proj_name.split(".")[:-1])
assert parent_module_name == ".".join(v_proj_name.split(".")[:-1])
parent_module = dict(model.named_modules())[parent_module_name]

k_proj = dict(model.named_modules())[k_proj_name]
v_proj = dict(model.named_modules())[v_proj_name]

kv_scale = max(k_proj.output_scale, v_proj.output_scale)
parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False)

# Remove output_scale from k_proj and v_proj
k_proj.output_scale = None
v_proj.output_scale = None
cleanup_memory()


def save_quantized_model(
model: AutoModelForCausalLM,
quant_config: BaseQuantizeConfig,
save_dir: str,
ignored_layers: List[str] = [],
):
print(model)
print(f"Saving the model to {save_dir}")
@@ -336,8 +276,6 @@ def save_quantized_model(
"ignored_layers": quant_config.ignored_layers,
}
}
if hasattr(quant_config, "kv_cache_quant_layers"):
static_q_dict["quantization_config"]["kv_cache_scheme"] = "static"
model.config.update(static_q_dict)
model.save_pretrained(save_dir)
tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
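
per_tensor_quantize is referenced throughout this file but its body is collapsed in the view above. For reference, a common symmetric per-tensor FP8-E4M3 formulation is sketched below; it illustrates the technique and is not necessarily the exact body in the repository.

import torch

def per_tensor_quantize_sketch(tensor: torch.Tensor):
    """Illustrative symmetric per-tensor quantization to FP8-E4M3."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = tensor.abs().max().clamp(min=1e-12)  # guard against an all-zero tensor
    scale = amax / finfo.max                    # dequant scale: x is approx q.to(x.dtype) * scale
    qtensor = (tensor / scale).clamp(min=finfo.min, max=finfo.max)
    return qtensor.to(torch.float8_e4m3fn), scale

The non-native path in fp8_gemm presumably relies on exactly this relationship, dequantizing both operands with their scales before a regular matmul.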