From 86b293b060441b289964fa19ef7dc2e087f28460 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Wed, 4 Sep 2024 19:48:53 -0700
Subject: [PATCH] adding kv cache quantization to READMEs

Summary: see READMEs

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 README.md                       |  6 ++++++
 torchao/_models/llama/README.md | 21 +++++++++++++++++++--
 torchao/quantization/README.md  |  8 +++++++-
 3 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 7257ff27d..40bc3805b 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,12 @@ model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
 
 We also provide a developer facing API so you can implement your own quantization algorithms so please use the excellent [HQQ](https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq) algorithm as a motivating example.
 
+### KV Cache Quantization
+
+We've added kv cache quantization and other features in order to enable long context length (and necessarily memory-efficient) inference.
+
+In practice these features alongside int4 weight-only quantization allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length using only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md).
+
 ### Quantization Aware Training
 
 Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
diff --git a/torchao/_models/llama/README.md b/torchao/_models/llama/README.md
index cfd9c353e..466901e74 100644
--- a/torchao/_models/llama/README.md
+++ b/torchao/_models/llama/README.md
@@ -8,5 +8,22 @@ and follow the steps to gain access.
 Then from the torchao root directory use `huggingface-cli login` and follow the steps to login,
 then `sh ./scripts/prepare.sh` to download and convert the model weights
 
-once done you can execute benchmarks from the torchao/_models/llama dir with `sh benchmarks.sh`. You can perform and benchmarking
-directly using `generate.py`.
+Once done, you can execute the benchmarks from the torchao/_models/llama dir with `sh benchmarks.sh`. You can also run benchmarking or evaluation
+directly using `generate.py` or `eval.py`.
+
+## KV Cache Quantization - Memory Efficient Inference
+We've added some features to `model.py` compared to the original gpt-fast implementation in order to enable long context length (and necessarily memory-efficient) inference. Specifically, we've added kv_cache quantization and a linear_causal_mask implementation, which together are **able to reduce memory usage by 50-60%** at long context lengths.
+
+In practice these features alongside int4 weight-only quantization allow us to do Llama3.1-8B inference with a **130k context length using only 18.9 GB of peak memory.**
+
+You can check it out yourself with `generate.py`. These features currently exist as a proof of concept and technical demonstration of the techniques; we're working on a way to release them in a more general form, but until then feel free to copy them into your own models.
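+
+To get a feel for what kv_cache quantization is doing, here is a minimal, self-contained sketch of the idea: store the cache in int8 with per-token scales and dequantize just-in-time for attention. This is purely illustrative (the helper names and shapes below are made up, and the exact quantization scheme used in `model.py` may differ); see the PR linked below for the actual implementation.
+
+```python
+import torch
+
+def quantize_int8(x: torch.Tensor):
+    # symmetric per-token quantization: one scale per (batch, head, token)
+    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
+    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
+    return q, scale
+
+def dequantize_int8(q: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
+    # upcast back to the compute dtype right before attention
+    return q.to(dtype) * scale.to(dtype)
+
+# a bf16 cache entry of shape (batch, n_heads, seq_len, head_dim); shapes are arbitrary
+k = torch.randn(1, 32, 1024, 128, dtype=torch.bfloat16)
+k_q, k_scale = quantize_int8(k)        # what gets stored: int8 values + per-token scales
+k_deq = dequantize_int8(k_q, k_scale)  # what attention sees
+print(k.numel() * k.element_size(), k_q.numel() * k_q.element_size())  # cache bytes: ~2x smaller, plus a small overhead for scales
+```
+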
+The details and a full explanation can be found in this [PR](https://github.com/pytorch/ao/pull/738).
+
+To see how these techniques scale more generally, we've run `generate.py` with subsets of these features at different context lengths on an A100 GPU. You can find the commands to reproduce these numbers in `benchmarks.sh`.
+
+| context length (tokens) | normal peak (GB) | kv_quant peak (GB) | kv quant+linear_causal_mask peak (GB) |
+|-------------------------|------------------|--------------------|---------------------------------------|
+| 8192                    | 17.86            | 17.52              | 17.47                                 |
+| 16384                   | 19.81            | 18.75              | 18.48                                 |
+| 32768                   | 23.83            | 21.72              | 20.64                                 |
+| 65536                   | 33.5             | 29.54              | 25.24                                 |
+| 131072                  | 59.27            | 52.62              | 34.18                                 |
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 6112abd02..26abe5182 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -25,7 +25,7 @@ note: Int8 dynamic quantization works best on compute bound models like [SAM](ht
 
 For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3bd8c674f2123af232a0231b5e38ddafa756a8/torchao/dtypes/aqt.py#L526) of `torch.ops.aten._weight_int4pack_mm` to bitpack into a layout optimized for tensor cores
 
-And a quick crash course on inference quantization to help parse the above table. Int4 quantization is an ambiguous term because there's the dtype in which a layer is represented and then the dtype in which the computation is done. For example, if you're using Weight-Only (wo) int4 quantization that means that the layer will be upcasted to a larger dtype like fp16 so an int4 matrix multiplication is defined as `F.linear(input, weight.to(input.dtype))`. Dynamic quantization (DQ) primarily targets activations, enabling on-the-fly quantization from higher precision formats like bf16 to lower precision formats such as int8. This process, when supported by hardware, allows for direct computation, such as performing `F.linear(input, weight)`. Naive quantization algorithms are also notoriously sensitive to outliers so we also typically set a group size that applies a scale factor per group of 64 elements in the case of `int4wo64`.
+And a quick crash course on inference quantization to help parse the above table. Int4 quantization is an ambiguous term because there's the dtype in which a layer is represented and then the dtype in which the computation is done. For example, if you're using Weight-Only (wo) int4 quantization that means that the layer will be upcasted to a larger dtype like fp16 so an int4 matrix multiplication is defined as `F.linear(input, weight.to(input.dtype))`. Dynamic quantization (DQ) primarily targets activations, enabling on-the-fly quantization from higher precision formats like bf16 to lower precision formats such as int8. This process, when supported by hardware, allows for direct computation, such as performing `F.linear(input, weight)`. Naive quantization algorithms are also notoriously sensitive to outliers so we also typically set a group size that applies a scale factor per group of 64 elements in the case of `int4wo-64`.
 
 ## Autoquantization
@@ -233,6 +233,12 @@ change_linear_weights_to_int4_woqtensors(model)
 
 Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
 
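+For intuition on what the group size in a name like `int4wo-64` means, here is a small, purely illustrative sketch of group-wise weight-only quantization and the upcast-then-matmul compute path described in the crash course above. This is not the tinygemm kernel or the torchao implementation (plain PyTorch has no int4 dtype, so the int4 values are held in int8 here), and the helper names and sizes are made up.
+
+```python
+import torch
+import torch.nn.functional as F
+
+def quantize_groupwise_int4(weight: torch.Tensor, group_size: int = 64):
+    # one scale per group of `group_size` elements along each row of the weight
+    out_features, in_features = weight.shape
+    assert in_features % group_size == 0
+    w = weight.reshape(out_features, in_features // group_size, group_size)
+    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7.0  # signed int4 range is [-8, 7]
+    q = torch.clamp(torch.round(w / scale), -8, 7).to(torch.int8)     # int4 values held in int8
+    return q, scale
+
+def dequantize_groupwise(q: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
+    return (q.to(dtype) * scale.to(dtype)).reshape(q.shape[0], -1)
+
+weight = torch.randn(4096, 4096, dtype=torch.bfloat16)
+x = torch.randn(1, 4096, dtype=torch.bfloat16)
+
+q, scale = quantize_groupwise_int4(weight)  # int4 values + one scale per group of 64 (real kernels pack these into an optimized layout)
+w_upcast = dequantize_groupwise(q, scale)   # weight-only: upcast the weight at compute time
+y = F.linear(x, w_upcast)                   # the matmul itself still runs in bf16
+```
+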
+### KV Cache Quantization
+
+We've added kv cache quantization and other features in order to enable long context length (and necessarily memory-efficient) inference.
+
+In practice these features alongside int4 weight-only quantization allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length using only 18.9 GB of peak memory.** More details can be found [here](../_models/llama/README.md).
+
 ## (To be moved to prototype) A16W4 WeightOnly Quantization with GPTQ
 
 ```python