adding kv cache quantization to READMEs
Summary: see READMEs

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles committed Sep 5, 2024
1 parent 1c488e8 commit 86b293b
Showing 3 changed files with 32 additions and 3 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -48,6 +48,12 @@ model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

We also provide a developer-facing API so you can implement your own quantization algorithms; please use the excellent [HQQ](https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq) algorithm as a motivating example.

### KV Cache Quantization

We've added kv cache quantization and other features in order to enable long context length (and necessarily memory-efficient) inference.

In practice these features, alongside int4 weight-only quantization, allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length using only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md).

### Quantization Aware Training

Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional post-training quantization (PTQ), recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to PTQ. A full recipe is provided [here](https://pytorch.org/blog/quantization-aware-training/).
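
For intuition on what QAT changes during training, here is a minimal, self-contained sketch of fake quantization with a straight-through estimator. It is illustrative only and is not the Torchtune/torchao recipe linked above; the symmetric scheme, group size, and `QATLinear` name are assumptions made for the example.

```python
import torch
import torch.nn.functional as F

def fake_quantize(w: torch.Tensor, bits: int = 4, group_size: int = 64) -> torch.Tensor:
    """Simulate symmetric groupwise int quantization in the forward pass.
    Assumes w.numel() is divisible by group_size."""
    wg = w.reshape(-1, group_size)
    qmax = 2 ** (bits - 1) - 1                                    # 7 for int4
    scale = wg.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / qmax
    wq = (torch.clamp(torch.round(wg / scale), -qmax - 1, qmax) * scale).reshape(w.shape)
    # Straight-through estimator: the forward pass sees the quantize->dequantize
    # round trip, the backward pass treats it as identity so gradients reach w.
    return w + (wq - w).detach()

class QATLinear(torch.nn.Linear):
    """Linear layer that trains against its fake-quantized weight, so the
    learned weights already tolerate the eventual low-bit conversion."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, fake_quantize(self.weight), self.bias)
```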
21 changes: 19 additions & 2 deletions torchao/_models/llama/README.md
@@ -8,5 +8,22 @@ and follow the steps to gain access.
Then from the torchao root directory use `huggingface-cli login` and follow the steps to log in, then run `sh ./scripts/prepare.sh` to download and convert the model weights.

Once done, you can execute benchmarks from the torchao/_models/llama dir with `sh benchmarks.sh`. You can also run benchmarking or evaluation directly using `generate.py` or `eval.py`.

## KV Cache Quantization - Memory Efficient Inference
We've added some features to `model.py`, compared to the original gpt-fast implementation, in order to enable long context length (and necessarily memory-efficient) inference. Specifically, we've added kv_cache quantization and a linear_causal_mask implementation, which together are **able to reduce memory usage by 50-60%** at long context lengths.
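
As a rough sketch of the kv_cache quantization idea (illustrative only, not the code in `model.py`; the per-position symmetric int8 scheme and the `QuantizedKVCache` name are assumptions made for the example), the cache stores k/v as int8 plus small per-position scales and dequantizes when the cache is read:

```python
import torch

class QuantizedKVCache(torch.nn.Module):
    """Illustrative int8 kv cache: store k/v as int8 with per-position scales
    and dequantize on read, roughly halving a bf16 cache's footprint."""

    def __init__(self, max_seq_len: int, n_heads: int, head_dim: int, dtype=torch.bfloat16):
        super().__init__()
        shape = (1, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_int8", torch.zeros(shape, dtype=torch.int8))
        self.register_buffer("v_int8", torch.zeros(shape, dtype=torch.int8))
        # one scale per (head, position) -- tiny compared to the cache itself
        self.register_buffer("k_scale", torch.ones(1, n_heads, max_seq_len, 1, dtype=dtype))
        self.register_buffer("v_scale", torch.ones(1, n_heads, max_seq_len, 1, dtype=dtype))

    @staticmethod
    def _quantize(x: torch.Tensor):
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
        return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8), scale

    def update(self, pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Write new k/v of shape [1, n_heads, len(pos), head_dim] at `pos`,
        then return the dequantized cache for attention."""
        k_q, k_s = self._quantize(k)
        v_q, v_s = self._quantize(v)
        self.k_int8[:, :, pos], self.k_scale[:, :, pos] = k_q, k_s
        self.v_int8[:, :, pos], self.v_scale[:, :, pos] = v_q, v_s
        # A real implementation would avoid dequantizing the full cache here.
        return self.k_int8 * self.k_scale, self.v_int8 * self.v_scale
```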

In practice these features, alongside int4 weight-only quantization, allow us to do Llama3.1-8B inference with a **130k context length using only 18.9 GB of peak memory.**

You can check it out yourself with `generate.py`. These features exist as a proof of concept and technical demonstration of the techniques, though we're working to figure out a way to release them more generally. Until then, feel free to copy these features into your own models. The details and a full explanation can be found in this [PR](https://github.com/pytorch/ao/pull/738).
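
The PR above is the authoritative reference for the linear_causal_mask; as a rough illustration of the idea (our reading, not the actual `model.py` code), during decoding a single query token attends to every cached position up to the current one, so a reusable `[1, max_seq_len]` boolean row can stand in for the usual `[max_seq_len, max_seq_len]` causal mask buffer, shrinking mask memory from quadratic to linear in context length:

```python
import torch

max_seq_len = 131072

# A dense lower-triangular causal mask is O(max_seq_len^2): at 131072 positions
# that is ~16 GiB even stored as bool, which dominates peak memory on its own.
# full_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

# Linear alternative (illustrative): one boolean row, O(max_seq_len), updated as
# decoding advances. With a single query token per step, this row is the only
# mask the attention call needs.
linear_mask = torch.zeros(1, max_seq_len, dtype=torch.bool)

def causal_row(pos: int) -> torch.Tensor:
    """Mark position `pos` visible (earlier positions were marked on previous
    steps) and return the row to use as the attention mask for this step."""
    linear_mask[0, pos] = True
    return linear_mask

attn_mask = causal_row(5)   # [1, max_seq_len], True where attention is allowed
```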

To see how these techniques scale more generally, we've run `generate.py` with subsets of these features for different context lengths on an A100 GPU. You can find commands to reproduce these numbers in `benchmarks.sh`.

| context length (tokens) | normal peak (GB) | kv_quant peak (GB) | kv quant+linear_causal_mask peak (GB) |
|-------------------------|------------------|--------------------|---------------------------------------|
| 8192 | 17.86 | 17.52 | 17.47 |
| 16384 | 19.81 | 18.75 | 18.48 |
| 32768 | 23.83 | 21.72 | 20.64 |
| 65536 | 33.5 | 29.54 | 25.24 |
| 131072 | 59.27 | 52.62 | 34.18 |
8 changes: 7 additions & 1 deletion torchao/quantization/README.md
@@ -25,7 +25,7 @@ note: Int8 dynamic quantization works best on compute bound models like [SAM](ht

For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3bd8c674f2123af232a0231b5e38ddafa756a8/torchao/dtypes/aqt.py#L526) via `torch.ops.aten._weight_int4pack_mm` to bitpack into a layout optimized for tensor cores.

Here's a quick crash course on inference quantization to help parse the above table. Int4 quantization is an ambiguous term because there's the dtype in which a layer is represented and then the dtype in which the computation is done. For example, if you're using weight-only (wo) int4 quantization, the layer is upcast to a larger dtype like fp16, so an int4 matrix multiplication is defined as `F.linear(input, weight.to(input.dtype))`. Dynamic quantization (DQ) primarily targets activations, enabling on-the-fly quantization from higher-precision formats like bf16 to lower-precision formats such as int8. This process, when supported by hardware, allows for direct computation, such as performing `F.linear(input, weight)`. Naive quantization algorithms are also notoriously sensitive to outliers, so we typically set a group size that applies a scale factor per group of 64 elements in the case of `int4wo-64`.
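
To make the group-size point concrete, here is a small sketch of groupwise weight quantization and the weight-only compute pattern described above. It is illustrative only (symmetric quantization, unpacked int8 storage) and deliberately skips the bit-packed tensor-core layout that the real tinygemm path uses.

```python
import torch
import torch.nn.functional as F

def quantize_int4_groupwise(w: torch.Tensor, group_size: int = 64):
    """Groupwise symmetric int4 quantization (simplified: values are kept in an
    unpacked int8 tensor rather than tinygemm's bit-packed layout)."""
    out_features, in_features = w.shape
    wg = w.reshape(out_features, in_features // group_size, group_size)
    scale = wg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7.0   # int4 range is [-8, 7]
    w_int4 = torch.clamp(torch.round(wg / scale), -8, 7).to(torch.int8)
    return w_int4, scale

def int4_weight_only_linear(x: torch.Tensor, w_int4: torch.Tensor, scale: torch.Tensor, bias=None):
    """Weight-only (wo) compute: dequantize the weight up to the activation
    dtype and run the matmul in higher precision, i.e.
    F.linear(input, weight.to(input.dtype))."""
    w = (w_int4 * scale).reshape(w_int4.shape[0], -1).to(x.dtype)
    return F.linear(x, w, bias)

# `int4wo-64`: one scale per group of 64 input elements per output channel
x = torch.randn(2, 4096, dtype=torch.bfloat16)
w = torch.randn(11008, 4096)
w_int4, scale = quantize_int4_groupwise(w, group_size=64)
y = int4_weight_only_linear(x, w_int4, scale)   # matmul runs in bf16
```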

## Autoquantization

@@ -233,6 +233,12 @@ change_linear_weights_to_int4_woqtensors(model)

Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.

### KV Cache Quantization

We've added kv cache quantization and other features in order to enable long context length (and necessarily memory-efficient) inference.

In practice these features, alongside int4 weight-only quantization, allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length using only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md).

## (To be moved to prototype) A16W4 WeightOnly Quantization with GPTQ

