
[RFC] Adding overrides for max cache seq length #1449

Merged 26 commits into pytorch:main on Sep 16, 2024

Conversation

@SalmanMohammadi (Collaborator) commented on Aug 29, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

#1364

Changelog

This PR:

  • Adds support for overriding the maximum sequence length used when setting up KV-caches.
  • Adds support for correctly setting up caches for self-attention, cross-attention, and fusion layers by exposing encoder and decoder max_seq_len args. These arguments are exposed on the top-level transformer classes (i.e. TransformerDecoder, DeepFusionModel) so that the API remains the same for all models: model.setup_caches(bsz, dtype, encoder_max_seq_len, decoder_max_seq_len). See the first sketch below.
  • Removes the use of input_pos to update and retrieve from the KV-cache; instead, the KV-cache tracks its own position. See the second sketch below.
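To make the first bullet concrete, here is a minimal sketch of the overridden cache setup. The `llama2_7b` builder and the exact keyword defaults are illustrative assumptions; see the torchtune docs for the authoritative signature.

```python
import torch
from torchtune.models.llama2 import llama2_7b

# Builder choice is illustrative; any TransformerDecoder-based model works.
model = llama2_7b()

# Before this PR, caches were always allocated at the model's full
# max_seq_len. With the override, a short decoding run can allocate a
# much smaller cache:
model.setup_caches(
    batch_size=2,
    dtype=torch.bfloat16,
    decoder_max_seq_len=512,  # override: cache only 512 positions
)

# For fusion models (e.g. DeepFusionModel), encoder_max_seq_len controls
# the cross-attention cache length in the same call.
```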
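And for the second bullet, a toy sketch of a cache that maintains its own write position, so callers no longer thread input_pos through update and retrieval. This is not torchtune's actual KVCache implementation, just an illustration of the idea:

```python
import torch


class SimpleKVCache(torch.nn.Module):
    """Toy KV-cache that tracks its own current position.

    Illustrative only: torchtune's actual KVCache differs in its details
    (buffer registration, masking, resets).
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        max_seq_len: int,
        head_dim: int,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))
        self.size = 0  # current write position, advanced on every update

    def update(self, k: torch.Tensor, v: torch.Tensor):
        # k, v: [batch, num_heads, new_seq_len, head_dim] for the new tokens.
        # Note there is no input_pos argument: the cache knows where to write.
        new_seq_len = k.shape[2]
        self.k_cache[:, :, self.size : self.size + new_seq_len] = k
        self.v_cache[:, :, self.size : self.size + new_seq_len] = v
        self.size += new_seq_len
        # Return the full caches; attention masks out unwritten positions.
        return self.k_cache, self.v_cache

    def reset(self) -> None:
        self.size = 0
```

During decoding, each forward pass simply calls cache.update(k, v) with the new tokens' keys and values, and reset() rewinds the cache between prompts.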

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any of these, just ask and we'll happily help. We also have a contributing page for guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:
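Since the original example isn't reproduced here, below is a hedged sketch of what the updated signature and docstring might look like. Parameter names follow the changelog above; the defaults and docstring wording are assumptions.

```python
from typing import Optional

import torch


def setup_caches(
    self,
    batch_size: int,
    dtype: torch.dtype,
    *,
    encoder_max_seq_len: Optional[int] = None,
    decoder_max_seq_len: Optional[int] = None,
) -> None:
    """Set up key-value caches for attention layers.

    Args:
        batch_size (int): batch size for the cache entries.
        dtype (torch.dtype): dtype for the cache tensors.
        encoder_max_seq_len (Optional[int]): maximum cache length for
            cross-attention layers; if None, falls back to the model's
            encoder max sequence length.
        decoder_max_seq_len (Optional[int]): maximum cache length for
            self-attention layers; if None, falls back to the model's
            ``max_seq_len``.
    """
    ...
```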


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;

pytorch-bot bot commented Aug 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1449

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3fc1135 with merge base bc2c013:
💚 Looks good so far! There are no failures yet. 💚


@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 29, 2024
@SalmanMohammadi SalmanMohammadi changed the title [WIP] Refactoring KV-cache setup, adding overrides for max cache seq length [WIP][RFC] Refactoring KV-cache setup, adding overrides for max cache seq length Aug 29, 2024
@codecov-commenter commented on Aug 29, 2024

Codecov Report

Attention: Patch coverage is 81.40704% with 37 lines in your changes missing coverage. Please review.

Project coverage is 72.90%. Comparing base (726abb0) to head (3fc1135).
Report is 7 commits behind head on main.

Files with missing lines               Patch %   Missing lines
torchtune/modules/transformer.py       55.88%    30
torchtune/models/gemma/transformer.py   0.00%     6
recipes/eleuther_eval.py                0.00%     1
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1449      +/-   ##
==========================================
+ Coverage   70.72%   72.90%   +2.18%     
==========================================
  Files         288      289       +1     
  Lines       14213    14336     +123     
==========================================
+ Hits        10052    10452     +400     
+ Misses       4161     3884     -277     


@SalmanMohammadi SalmanMohammadi changed the title [WIP][RFC] Refactoring KV-cache setup, adding overrides for max cache seq length [RFC] Refactoring KV-cache setup, adding overrides for max cache seq length Sep 4, 2024
@SalmanMohammadi SalmanMohammadi marked this pull request as ready for review September 4, 2024 22:02
@SalmanMohammadi SalmanMohammadi changed the title [RFC] Refactoring KV-cache setup, adding overrides for max cache seq length [RFC] Adding overrides for max cache seq length Sep 4, 2024
@SalmanMohammadi (Collaborator, Author) commented on Sep 6, 2024

What's up with the eleuther eval tests? They're passing locally for me.

@ebsmothers (Contributor) left a comment

Overall it looks reasonable to me, will leave it to @pbontrager or @joecummings for the final sign-off

Outdated review comments (resolved) on:
  • torchtune/modules/model_fusion/_fusion.py (×2)
  • torchtune/modules/transformer.py (×2)
  • torchtune/utils/_generation.py (×2)
@SalmanMohammadi (Collaborator, Author) commented:
Thanks so much for the reviews @ebsmothers. Will address once #1424 lands and things are merged.

@joecummings (Contributor) left a comment

No concerns, but pls address @ebsmothers comments.

@SalmanMohammadi SalmanMohammadi merged commit 1e9dc42 into pytorch:main Sep 16, 2024
17 checks passed
@SalmanMohammadi SalmanMohammadi deleted the setup_cache_refactor branch September 17, 2024 09:58