The advantage of multi-query attention (MQA) lies in both reducing the size of the KV cache and making self-attention computation more efficient. The current implementation only saves on KV cache size. This PR improves it further by not only reducing the computation cost, but also saving the per-layer KV cache memory.

This becomes especially critical when dealing with very long contexts. For instance, if an LLM is processing a context length of 1 million tokens using the Character.ai architecture [1], there might be around 4 unique KV cache layers. Let’s assume there are 4 KV heads and 32 total attention heads, with a dim_per_head of 128.

In the current implementation, each layer consumes significant memory for self-attention KV caching (using bfloat16):
* Current (ASIS): 8GB (128 * 32 * 2 * 1M)
* Optimized (TODO): 1GB (128 * 4 * 2 * 1M)

[1] https://research.character.ai/optimizing-inference/

* Benchmark results: it saves memory and computation.

tools/attention_benchmark.py on TPUv5p

ASIS
-----------------------------------------------------------------------------------------
Benchmark                            Time         CPU   Iterations  HBM (over 95.74G)
-----------------------------------------------------------------------------------------
MQABenchmark/2048/16/2/1024       1.42 ms    0.247 ms         2347            291.16M
MQABenchmark/4096/16/2/1024       3.60 ms    0.277 ms         1257            322.95M
MQABenchmark/4096/16/2/4096       47.3 ms    0.818 ms          139              4.25G
MQABenchmark/4096/16/2/8192        869 ms    0.932 ms          140             48.00G

This PR
-----------------------------------------------------------------------------------------
Benchmark                            Time         CPU   Iterations  HBM (over 95.74G)
-----------------------------------------------------------------------------------------
MQABenchmark/2048/16/2/1024       1.16 ms    0.256 ms         2535            262.35M
MQABenchmark/4096/16/2/1024       3.46 ms    0.294 ms         1114            266.88M
MQABenchmark/4096/16/2/4096       24.8 ms    0.769 ms          137              4.04G
MQABenchmark/4096/16/2/8192        860 ms     1.19 ms          136             48.00G
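The per-layer memory figures quoted in the commit message can be reproduced with a few lines of Python. This is only a sketch of the arithmetic as written there: the helper name `kv_cache_bytes` is hypothetical, "1M" is assumed to mean 2**20 tokens, and the factor of 2 is taken as given because the message does not say whether it stands for the bfloat16 byte width or for the K/V pair.

```python
# Sketch of the per-layer KV cache estimate from the commit message.
# Assumptions (not stated in the source): "1M" tokens == 2**20, and the
# factor of 2 is used exactly as written, without interpreting it.

ONE_M = 1 << 20   # assumed meaning of "1M" tokens
GIB = 1 << 30     # bytes per GiB


def kv_cache_bytes(dim_per_head: int, num_heads: int, seq_len: int, factor: int = 2) -> int:
    """Return dim_per_head * num_heads * factor * seq_len, as in the message."""
    return dim_per_head * num_heads * factor * seq_len


# ASIS: the cache is sized by all 32 attention heads.
asis = kv_cache_bytes(dim_per_head=128, num_heads=32, seq_len=ONE_M)

# Optimized: the cache is sized by the 4 KV heads only.
optimized = kv_cache_bytes(dim_per_head=128, num_heads=4, seq_len=ONE_M)

print(f"ASIS:      {asis / GIB:.0f} GiB")
print(f"Optimized: {optimized / GIB:.0f} GiB")
print(f"Reduction: {asis // optimized}x")
```

Under these assumptions the reduction factor equals the ratio of total attention heads to KV heads (32 / 4 = 8), independent of the context length.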
Showing 5 changed files with 79 additions and 66 deletions.