Support calibrating kv cache scales #17

Merged: mgoin merged 10 commits from support-kv-cache-scales into main on Jun 18, 2024

Conversation

@mgoin (Member) commented on Jun 13, 2024

Adds a kv_cache_quant_targets quant config argument that attaches output_scales to the specified Linear modules. This means we will end up with k_proj.output_scale and v_proj.output_scale after activation calibration. For the final checkpoint, we add a pass to take the maximum of k_proj.output_scale and v_proj.output_scale, and place the result in the parent of those modules (the Attention module) as a single kv_scale, which is needed to match the representation in vLLM.
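For illustration, a minimal sketch of what that merge pass could look like is shown below. The function name and the attribute handling are hypothetical and based only on the description above (k_proj/v_proj carrying output_scale, the parent Attention module receiving kv_scale); the PR's actual implementation may differ.

import torch
from torch import nn

def merge_kv_scales(model: nn.Module) -> None:
    # Sketch only: fold the calibrated k_proj/v_proj output scales into a
    # single kv_scale on the parent attention module, as described above.
    for module in model.modules():
        k_proj = getattr(module, "k_proj", None)
        v_proj = getattr(module, "v_proj", None)
        if k_proj is None or v_proj is None:
            continue
        k_scale = getattr(k_proj, "output_scale", None)
        v_scale = getattr(v_proj, "output_scale", None)
        if k_scale is None or v_scale is None:
            continue
        # Take the maximum of the two calibrated output scales and attach it
        # to the parent (Attention) module as the single kv_scale vLLM expects.
        module.kv_scale = nn.Parameter(torch.max(k_scale, v_scale), requires_grad=False)
        # The per-projection scales are no longer needed in the final checkpoint.
        del k_proj.output_scale
        del v_proj.output_scale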

Also includes a decent chunk of refactoring to allow for no examples to be passed in for weight quantization, renaming for clearer understanding of modules, making "re:.*lm_head" not a required ignored pattern but just a default, and disabling torch._scaled_mm for easier usage on CPU.
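As a rough sketch of the no-examples path mentioned above (weight-only quantization with dynamic activation scales), usage could look like the following; the "dynamic" scheme value and the empty examples list are assumptions based on that description, not code taken from this PR.

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

# Assumed: dynamic activation scales mean no static calibration data is required.
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="dynamic",
)

model = AutoFP8ForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", quantize_config)
model.quantize(examples=[])  # weight-only quantization, no calibration examples
model.save_quantized("Meta-Llama-3-8B-Instruct-FP8")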

A new example is included to show how to enable this functionality:

from datasets import load_dataset
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8-KV"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignore_patterns=["re:.*lm_head"],
    # Calibrate output_scale on these Linear modules; they are later merged
    # into a single kv_scale on each parent Attention module.
    kv_cache_quant_targets=("k_proj", "v_proj"),
)

model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
model.quantize(examples)
model.save_quantized(quantized_model_dir)
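
As an optional sanity check (not part of this PR), the saved checkpoint can be inspected for the merged scales; the single-shard model.safetensors file name below is an assumption about how the export is laid out.

from safetensors import safe_open

# Assumes a single-shard safetensors export in the output directory.
with safe_open(f"{quantized_model_dir}/model.safetensors", framework="pt") as f:
    kv_scale_keys = [k for k in f.keys() if k.endswith("kv_scale")]

# Expect one entry per attention layer, e.g. model.layers.0.self_attn.kv_scale
print(kv_scale_keys)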

@mgoin force-pushed the support-kv-cache-scales branch from 96cc0b0 to 084feb8 on June 14, 2024 at 18:57
@mgoin marked this pull request as ready for review on June 14, 2024 at 18:57
@mgoin mentioned this pull request on Jun 18, 2024
@mgoin merged commit 0d40b99 into main on Jun 18, 2024
4 checks passed
mgoin added a commit that referenced this pull request on Jun 19, 2024
@mgoin linked an issue on Jun 20, 2024 that may be closed by this pull request
Successfully merging this pull request may close these issues: FP8 KV cache support.