Llama3.1 and kv_cache quantization #738
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/738
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit c5e4dcb with merge base 37276d6. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Mostly looking good!
Summary:
TODO: finish kv_cache testing, generate memory_trace
Added the 3.1 frequency rescaling and model definitions; testing is ongoing.

Test Plan:
python eval.py --checkpoint_path $../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --compile
wikitext: {'word_perplexity,none': 7.441690325135099, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.4554823564993407, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.541497351075118, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 16384
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 16384
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 32768
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 32768
Force-pushed from 3f712df to 08d598c.
This feels good to merge to me. FWIW, @iseeyuan and @felipemello1 have also noticed a large difference from reducing the memory requirements of the logits (pytorch/executorch#4688) from O(context length) to O(1).

Also remind me, did you also quantize the model? (Seems like no?) I'm trying to see if we can hit a 24GB VRAM budget or whether we need to explore int4 kv quantization. It'd be pretty sick to do a full Llama 8B inference on a 128K context length; that should, for example, be enough to fit the entire AO code repo.

Also, mind adding the top-line VRAM requirements at the top? The line chart doesn't have even ranges on the y-axis (log scale?), so it's a bit hard to eyeball.
Could the mask be generated via a broadcasting trick (an arange broadcast one way and compared against another arange broadcast differently) to avoid the need for ones followed by tril? Or not in this context? Otherwise, does FlexAttention allow avoiding materialization of such masks and computing the masking directly during attention? (I thought flash attention supported such materialization-free causal masks too...)
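For reference, a minimal sketch of the broadcasting trick asked about here (illustrative only, not code from this PR): the full mask is still materialized, but without the intermediate ones tensor and the tril copy.

```python
import torch

# Comparing two broadcast aranges yields the same lower-triangular boolean mask as
# torch.tril(torch.ones(n, n, dtype=torch.bool)), without first allocating a ones
# tensor and then copying it into triangular form.
def causal_mask_via_broadcast(n: int, device=None) -> torch.Tensor:
    rows = torch.arange(n, device=device).unsqueeze(1)  # shape (n, 1)
    cols = torch.arange(n, device=device).unsqueeze(0)  # shape (1, n)
    return cols <= rows  # broadcasts to (n, n): True where a key position is visible to a query

# sanity check against the ones + tril construction
n = 8
assert torch.equal(
    causal_mask_via_broadcast(n),
    torch.tril(torch.ones(n, n, dtype=torch.bool)),
)
```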
Great work on the memory optimizations! Have you measured any impact on model accuracy or perplexity with this method?
This PR adds support for Llama 3.1, along with some improvements to kv_cache quantization and general peak memory performance for Llama.
At a high level, we can now do inference with a 130k context length in 18.9 GB peak memory if we apply kv_cache quantization, the linear causal mask, and int4 weight-only quantization.
Summary of changes
Change to quantized kv_cache init
The first change avoids creating the full precision kv_cache. Previously we would initialize the kv_cache and then convert it to the quantized form, as seen in this memory profile:
The horizontal lines from ~16.1 GB to 16.6 GB are the normal kv_cache, and you can see them being deallocated on the right side of the image as the quantized kv_caches are instantiated. This created an unnecessary increase in peak memory whenever the initialization was the peak (which was the case for very long context lengths).
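A hedged sketch of the shape of this change (class, argument, and buffer names are illustrative, not the exact torchao code): allocate the quantized storage directly instead of building a full-precision cache and converting it.

```python
import torch

class QuantizedKVCache(torch.nn.Module):
    # Illustrative sketch only; the real implementation lives in torchao's llama model code.
    def __init__(self, max_batch_size, n_heads, max_seq_length, head_dim, device=None):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
        # Before: a bf16 cache was allocated first and then quantized, so both tensors
        # briefly coexisted and raised peak memory. After: allocate the int8 storage and
        # per-token scales directly, never creating the full-precision cache at all.
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=torch.int8, device=device))
        self.register_buffer("k_scale", torch.ones(scale_shape, dtype=torch.bfloat16, device=device))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=torch.int8, device=device))
        self.register_buffer("v_scale", torch.ones(scale_shape, dtype=torch.bfloat16, device=device))
```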
Change to causal mask
This is a memory profile for a 32k context length without kv_cache quantization or any other changes; compare it to one with kv_cache quantization:
The horizontal bands that run from 16 GB to 20.5 GB in the top image and to 18.5 GB in the bottom are the kv_cache. With quantization it's 2 GB smaller, which shows the technique is performing as expected. However, there is a large blue (top) or green (bottom) blob, with a spike on the left side, in the memory profile: this is the causal mask.
Normally the causal mask is handled by creating a (token length x token length) tensor of ones, then creating a lower-triangular copy of it and taking slices from it throughout the model runs. Notice the sharp peak right at the start: it occurs because copying a tensor of ones into a lower-triangular matrix requires holding two instances of it in memory for a moment, doubling its impact on top of the O(context_length^2) memory it already takes. The doubling issue was solved by creating the causal mask before the kv_cache; done that way, the momentary doubling spike doesn't affect the peak memory since the kv_cache allocation will be higher than the spike (see the sketch below).
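A rough sketch of that ordering fix, with a hypothetical setup_caches helper and explicit cache shapes (the actual code lives in the model's cache setup):

```python
import torch

def setup_caches(n_layers: int, max_seq_length: int, cache_shape, device=None):
    """Hypothetical helper illustrating the allocation order described above."""
    # 1) Causal mask first: torch.tril briefly holds both the ones tensor and its
    #    lower-triangular copy, which is the ~2x spike visible in the profile.
    causal_mask = torch.tril(
        torch.ones(max_seq_length, max_seq_length, dtype=torch.bool, device=device)
    )
    # 2) kv_caches afterwards: their larger allocation now sets the peak, so the
    #    earlier, momentary mask spike no longer defines the high-water mark.
    kv_caches = [
        torch.zeros(cache_shape, dtype=torch.bfloat16, device=device)
        for _ in range(n_layers)
    ]
    return causal_mask, kv_caches
```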
Although the earlier instantiation of the causal mask helps (it is the red blob now), it still takes up a lot of space, especially at even higher context lengths, which eats into the gains we expect from kv_cache quantization. Why do we need to store the full causal mask at all, though? A slice of the causal mask is essentially just a sequence of n ones in a row followed by context_length - n zeros, where n is the index of the current token being generated. Each slice differs from the next by only a single value, so we can store just the slice and update it each iteration instead (a minimal sketch follows). Result:
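For illustration, a minimal sketch of this slice-and-update approach (illustrative names, not the exact generate.py code; the PR exposes the feature via the --linear_causal_mask flag used in the benchmark below):

```python
import torch

# Keep one (1, max_seq_length) boolean row and flip a single entry per decoded token,
# instead of materializing the full (max_seq_length, max_seq_length) triangular mask.
max_seq_length = 1024  # 131072 in the benchmark below; kept small here
linear_causal_mask = torch.zeros(1, max_seq_length, dtype=torch.bool)

def advance_mask(mask: torch.Tensor, input_pos: int) -> torch.Tensor:
    # The token at position input_pos may attend to positions [0, input_pos]; the previous
    # step already set [0, input_pos - 1], so only one new entry changes per iteration.
    mask[0, input_pos] = True
    return mask  # O(max_seq_length) memory instead of O(max_seq_length ** 2)

# decode-loop sketch: update the mask once per generated token, then use it as the
# attention mask for that step (e.g. as attn_mask in scaled_dot_product_attention)
for pos in range(16):
    mask_slice = advance_mask(linear_causal_mask, pos)
```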
Tests:
See benchmarks.sh.
The 18.9 GB number came from:
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization --linear_causal_mask --quantization int4wo-64