Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the memory usage of logits from O(context_length) to O(1) #4688

Merged
merged 1 commit into from
Aug 23, 2024

Conversation

iseeyuan
Copy link
Contributor

@iseeyuan iseeyuan commented Aug 13, 2024

The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference.

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:

python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory

Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0.

Now the dominant memory usage would be KV cache.

TODO:

  • Improve KV cache memory usage using pf16 or quantization.
  • This PR only fixes logits. Further activation memory optimization with one token output.

Additional tests:
llava

python -m unittest examples.models.llava.test.test_llava -k test_prefill_logits
python -m unittest examples.models.llava.test.test_llava -k test_generated_output
python -m unittest examples.models.llava.test.test_llava -k test_llava_export

Copy link

pytorch-bot bot commented Aug 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/4688

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job

As of commit be438eb with merge base d7c069f (image):

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 13, 2024
@iseeyuan iseeyuan changed the title [testing] Reduce the memory usage of logistics from O(context_length) to O(1) Reduce the memory usage of logistics from O(context_length) to O(1) Aug 13, 2024
@iseeyuan iseeyuan changed the title Reduce the memory usage of logistics from O(context_length) to O(1) Reduce the memory usage of logits from O(context_length) to O(1) Aug 13, 2024
@facebook-github-bot
Copy link
Contributor

@iseeyuan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Aug 14, 2024
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference. 

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0. 

Now the dominant memory usage would be KV cache. 

TODO: 
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.


Differential Revision: D61246566
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D61246566

@larryliu0820
Copy link
Contributor

Can you fix llava CI job?

facebook-github-bot pushed a commit that referenced this pull request Aug 15, 2024
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference. 

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0. 

Now the dominant memory usage would be KV cache. 

TODO: 
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.


Differential Revision: D61246566

Pulled By: iseeyuan
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D61246566

facebook-github-bot pushed a commit that referenced this pull request Aug 16, 2024
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference. 

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0. 

Now the dominant memory usage would be KV cache. 

TODO: 
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.


Differential Revision: D61246566

Pulled By: iseeyuan
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D61246566

facebook-github-bot pushed a commit that referenced this pull request Aug 17, 2024
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference. 

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0. 

Now the dominant memory usage would be KV cache. 

TODO: 
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.


Differential Revision: D61246566

Pulled By: iseeyuan
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D61246566

facebook-github-bot pushed a commit that referenced this pull request Aug 18, 2024
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference. 

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0. 

Now the dominant memory usage would be KV cache. 

TODO: 
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.


Differential Revision: D61246566

Pulled By: iseeyuan
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D61246566

facebook-github-bot pushed a commit that referenced this pull request Aug 19, 2024
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference. 

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0. 

Now the dominant memory usage would be KV cache. 

TODO: 
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.


Differential Revision: D61246566

Pulled By: iseeyuan
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D61246566

facebook-github-bot pushed a commit that referenced this pull request Aug 19, 2024
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference. 

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0. 

Now the dominant memory usage would be KV cache. 

TODO: 
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.


Differential Revision: D61246566

Pulled By: iseeyuan
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D61246566

facebook-github-bot pushed a commit that referenced this pull request Aug 21, 2024
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference. 

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0. 

Now the dominant memory usage would be KV cache. 

TODO: 
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.


Reviewed By: larryliu0820

Differential Revision: D61246566

Pulled By: iseeyuan
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D61246566

facebook-github-bot pushed a commit that referenced this pull request Aug 21, 2024
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference. 

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0. 

Now the dominant memory usage would be KV cache. 

TODO: 
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.


Reviewed By: larryliu0820

Differential Revision: D61246566

Pulled By: iseeyuan
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D61246566

1 similar comment
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D61246566

iseeyuan pushed a commit that referenced this pull request Aug 21, 2024
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference.

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0.

Now the dominant memory usage would be KV cache.

TODO:
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.

Pull Request resolved: #4688

Reviewed By: larryliu0820

Differential Revision: D61246566

Pulled By: iseeyuan
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference. 

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0. 

Now the dominant memory usage would be KV cache. 

TODO: 
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.


Reviewed By: larryliu0820

Differential Revision: D61246566

Pulled By: iseeyuan
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D61246566

@facebook-github-bot facebook-github-bot merged commit 11e8ed3 into main Aug 23, 2024
89 of 92 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants