[DRAFT] Factor out core SDPA #1561
base: main
Conversation
Dr. CI bot: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1561. Note: links to docs will display an error until the docs builds have completed. As of commit f506e22 with merge base 9a863c8: ❌ 4 new failures, 5 cancelled jobs (please retry the cancelled jobs). This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hey @dvorjackz can you share the motivation for this PR? (I know it's still a draft, just wanna understand what the goal is for when I do eventually review it)
@ebsmothers added context to the PR description!
left some comments
Force-pushed from acffd0a to e61cf56.
Left another nit. Looks good to me. Make sure it is exportable. And I want to hear thoughts from the Tune folks.
Hey @dvorjackz, thanks for the PR and the extra context! It makes sense to me to make it swappable, but this is a relatively large change. I am not sure how this will interact with compile + multimodal. There are also other ongoing PRs that are modifying kv_cache, so we may need to align on the design a bit. I want to minimize the amount of work you have to do, but if your version is working, adding testing will make it much easier to approve (e.g. you could run supervised training for the vision model for 50 steps with / without the PR to confirm that everything works fine, and then run a generation task). But we should align on the design first.
Sorry for the delay in getting to this one. Tbh I am not sure I like the way we are refactoring the multi-head attention here. I think there is a pretty canonical MHA flow that folks in OSS are used to (see e.g. litgpt's CausalSelfAttention or transformers's LlamaAttention) and this would be diverging from that in a meaningful way. I am OK with such divergence if it makes the code easier to understand, but in this case we are actually adding another layer of abstraction that's not very clear (why do we need a separate module to handle a couple reshapes + SDPA? why do we call it SDPA when it in fact contains a call to `nn.functional.scaled_dot_product_attention`, which we then call `self._attn_fn`?)
Anyways this is not to be too harsh on this PR cause I do understand the motivation from the ET perspective. Just so that I am more well-informed here, can you share the optimized SDPA op from ET so I can look at it as a reference? Then maybe we can brainstorm about a less intrusive way we can achieve this.
Context
This PR factors out the optimizable portions of SDPA (namely the kv-cache update, the transpose, the expand, and the actual SDPA call) into a separate module. This allows a module containing optimized implementations of these operations to be easily swapped in for the new module via source transformation.
At the moment, ExecuTorch has an optimized SDPA op that does all of the above (kv-cache update, transpose, expand, SDPA) that we are hoping to swap in pre-export. cc @kimishpatel.
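To make the shape of the refactor concrete, here is a minimal sketch of what such a factored-out module could look like. This is an illustration only, not the PR's actual code: the class name `SDPA`, the `kv_cache.update` signature, and the tensor layouts are assumptions based on the description above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SDPA(nn.Module):
    """Hypothetical module bundling the inference-optimizable steps:
    kv-cache update, transpose, kv-head expand (GQA), and the SDPA call."""

    def __init__(self, num_heads: int, num_kv_heads: int, head_dim: int, kv_cache=None):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.kv_cache = kv_cache
        # Held as an attribute so an export-time transform could point it
        # at a fused/optimized kernel instead.
        self._attn_fn = F.scaled_dot_product_attention

    def forward(self, q, k, v, mask=None):
        # q: [b, s, n_h, h_d]; k, v: [b, s, n_kv, h_d]
        b, s, _, _ = q.shape
        q = q.transpose(1, 2)  # -> [b, n_h, s, h_d]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if self.kv_cache is not None:
            # Assumed cache API: returns the full cached k/v sequences.
            k, v = self.kv_cache.update(k, v)
        if self.num_heads != self.num_kv_heads:
            # Expand kv heads to match query heads (grouped-query attention).
            expand = self.num_heads // self.num_kv_heads
            k = k.repeat_interleave(expand, dim=1)
            v = v.repeat_interleave(expand, dim=1)
        out = self._attn_fn(q, k, v, attn_mask=mask)
        return out.transpose(1, 2).reshape(b, s, -1)
```

With this boundary in place, everything an optimized backend op covers lives behind one module, so a source transformation can replace the whole unit at once rather than patching individual tensor ops inside the attention layer.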
Proof of correctness
Before change: (loss curve screenshot)
After change: (loss curve screenshot)
We see essentially the same loss (a slight difference, since we only trained for one epoch). Interestingly, we also squeeze out ~9% more tokens per GPU after this refactor.
Changelog
Factor out inference-optimizable portions of SDPA
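The swap itself could be done with a simple recursive module replacement pre-export. This is a hypothetical sketch of the source-transformation idea, not an API from torchtune or ExecuTorch; `replace_sdpa` and the wrap-the-original convention are both illustrative assumptions.

```python
import torch.nn as nn


def replace_sdpa(module: nn.Module, optimized_cls) -> nn.Module:
    """Recursively swap every submodule whose class is named 'SDPA'
    for an instance of optimized_cls (hypothetical transform)."""
    for name, child in module.named_children():
        if type(child).__name__ == "SDPA":
            # Assumes optimized_cls wraps the original module so it can
            # reuse its config (num_heads, kv cache, etc.).
            setattr(module, name, optimized_cls(child))
        else:
            replace_sdpa(child, optimized_cls)
    return module
```

Because `setattr` on an `nn.Module` re-registers the child in `_modules`, the replacement is picked up by subsequent tracing/export without touching the rest of the model.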
Test plan
[Pending] Went through a quick export process as a sanity check; will do more extensive correctness checking.
- Run pre-commit hooks (`pre-commit install`)
- Run unit tests (`pytest tests`)
- Run integration tests (`pytest tests -m integration_test`)
UX
No public API changes