
[DRAFT] Factor out core SDPA #1561

Draft
wants to merge 8 commits into main
Conversation

@dvorjackz (Contributor) commented Sep 12, 2024

Context

This PR factors out the optimizable portions of SDPA (namely the KV cache update, the transpose, the expand, and the actual SDPA call) into a standalone module. This allows a module containing optimized implementations of these operations to be swapped in for the new module via source transformation.

At the moment, ExecuTorch (ET) has an optimized SDPA op that does all of the above (KV cache update, transpose, expand, SDPA), which we are hoping to swap in pre-export. cc @kimishpatel.
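
For illustration, here is a minimal sketch of the idea, not the actual diff: the class name, constructor arguments, cache interface, and `replace_sdpa` helper below are assumptions, with only `_attn_fn` taken from the discussion on this PR.

```python
from torch import nn


class SDPA(nn.Module):
    """Hypothetical factored-out module: KV cache update, transpose, expand, and the SDPA call."""

    def __init__(self, num_heads: int, num_kv_heads: int, kv_cache: nn.Module = None):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.kv_cache = kv_cache
        # The actual attention call, kept behind an attribute so it can be overridden.
        self._attn_fn = nn.functional.scaled_dot_product_attention

    def forward(self, q, k, v, mask=None):
        # Transpose: [b, s, n_h, h_d] -> [b, n_h, s, h_d]
        q, k, v = (x.transpose(1, 2) for x in (q, k, v))
        # KV cache update (assumed interface: returns the full cached k/v).
        if self.kv_cache is not None:
            k, v = self.kv_cache.update(k, v)
        # Expand: repeat kv heads to match query heads for grouped-query attention.
        if self.num_heads != self.num_kv_heads:
            reps = self.num_heads // self.num_kv_heads
            k = k.repeat_interleave(reps, dim=1)
            v = v.repeat_interleave(reps, dim=1)
        out = self._attn_fn(q, k, v, attn_mask=mask)
        # Back to [b, s, n_h, h_d] for the output projection in the caller.
        return out.transpose(1, 2).contiguous()


def replace_sdpa(module: nn.Module, make_optimized) -> nn.Module:
    """Source transformation: recursively swap every SDPA submodule for an optimized variant."""
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(module, name, make_optimized(child))
        else:
            replace_sdpa(child, make_optimized)
    return module
```

An ET-specific module exposing the same forward signature could then be swapped in pre-export, e.g. `replace_sdpa(model, lambda sdpa: OptimizedSDPA(sdpa))` (where `OptimizedSDPA` is a placeholder for the ET implementation), without touching the public attention API.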

Proof of correctness

Before change

$ tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device max_steps_per_epoch=25  epochs=1 metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True
.
.
.
1|25|Loss: 1.487076759338379: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [11:27<00:00, 27.49s/it]
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:               global_step ▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
wandb:                      loss ▆▆▇▆▆▆▇▇▅▇▅▅█▅▆▆▅▅▄▃▃▃▂▂▁
wandb:                        lr ▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▆▆▆▇▇▇▇██
wandb:        peak_memory_active ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:         peak_memory_alloc ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:      peak_memory_reserved ▁▁▂██████████████████████
wandb: tokens_per_second_per_gpu ▃▄▃▃▄▃▂▂▆▄▅█▁▅▂▂▂▄▄▃▃▄▃▂▃
wandb: 
wandb: Run summary:
wandb:               global_step 25
wandb:                      loss 1.48708
wandb:                        lr 7e-05
wandb:        peak_memory_active 16.20987
wandb:         peak_memory_alloc 16.20987
wandb:      peak_memory_reserved 18.77148
wandb: tokens_per_second_per_gpu 1313.99555
wandb: 
wandb: 🚀 View run flowing-lion-18 at: https://wandb.ai/dvorjackz-meta/torchtune/runs/imx4iros
wandb: ⭐️ View project at: https://wandb.ai/dvorjackz-meta/torchtune
wandb: Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 1 other file(s)
wandb: Find logs at: /tmp/wandb/run-20240920_154125-imx4iros/logs

After change

We see essentially the same loss (the slight difference is because we only trained for one epoch). Interestingly, we also squeeze out ~9% more tokens per second per GPU after this refactor.

# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device max_steps_per_epoch=25  epochs=1 metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True
.
.
.
1|25|Loss: 1.480817198753357: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [11:25<00:00, 27.41s/it]
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:               global_step ▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
wandb:                      loss ▆▆▇▇▆▆▇▇▅▇▅▅█▅▆▆▅▅▄▃▃▂▂▂▁
wandb:                        lr ▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▆▆▆▇▇▇▇██
wandb:        peak_memory_active ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:         peak_memory_alloc ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:      peak_memory_reserved ▁▁▂██████████████████████
wandb: tokens_per_second_per_gpu ▄▆▆▇▇▃▃▁▆▂▅█▂▄▃▅▃▃▃▃▄▄▃▃▆
wandb: 
wandb: Run summary:
wandb:               global_step 25
wandb:                      loss 1.48082
wandb:                        lr 7e-05
wandb:        peak_memory_active 16.20987
wandb:         peak_memory_alloc 16.20987
wandb:      peak_memory_reserved 18.77148
wandb: tokens_per_second_per_gpu 1434.81084
wandb: 
wandb: 🚀 View run fresh-puddle-17 at: https://wandb.ai/dvorjackz-meta/torchtune/runs/ehh7osgb
wandb: ⭐️ View project at: https://wandb.ai/dvorjackz-meta/torchtune
wandb: Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 1 other file(s)
wandb: Find logs at: /tmp/wandb/run-20240920_152752-ehh7osgb/logs

Changelog

Factor out inference-optimizable portions of SDPA

Test plan

[Pending] Went through a quick export process as a sanity check; will do more extensive correctness checking.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

No public API changes


pytorch-bot bot commented Sep 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1561

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 5 Cancelled Jobs

As of commit f506e22 with merge base 9a863c8:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 12, 2024
@ebsmothers (Contributor)

Hey @dvorjackz can you share the motivation for this PR? (I know it's still a draft, just wanna understand what the goal is for when I do eventually review it)

@dvorjackz (Contributor, Author)

@ebsmothers added context to the PR description!

@kimishpatel left a comment

left some comments

@kimishpatel

Left another nit. Looks good to me. Make sure it is exportable. And I want to hear thoughts from the Tune folks.
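
For reference, one way to sanity-check exportability with torch.export (a sketch only; it reuses the hypothetical `SDPA` module sketched in the PR description above, and the head counts and shapes are placeholders, not taken from this PR):

```python
import torch

# Placeholder instantiation of the hypothetical SDPA module sketched above,
# without a KV cache so the forward is a pure function of its inputs.
sdpa = SDPA(num_heads=32, num_kv_heads=8)

b, s, h_d = 1, 16, 128
q = torch.randn(b, s, 32, h_d)
k = torch.randn(b, s, 8, h_d)
v = torch.randn(b, s, 8, h_d)

# torch.export traces the module into an ExportedProgram; if this succeeds,
# the module should be usable in the pre-export source-transformation flow.
ep = torch.export.export(sdpa, (q, k, v))
ep.graph_module.print_readable()
```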

@felipemello1 (Contributor)

Hey @dvorjackz, thanks for the PR and the extra context! It makes sense to me to make it swappable, but this is a relatively large change. I am not sure how this will interact with compile + multimodal. There are also other ongoing PRs that are modifying kv_cache. We may need to align on the design a bit.

I want to minimize the amount of work you have to do, but if your version is working, adding testing will make it much easier to approve (e.g. you could run supervised training for the vision model for 50 steps with/without the PR to confirm that everything works fine, and then run a generation task). But we should align on the design first.

@ebsmothers (Contributor) left a comment

Sorry for the delay in getting to this one. Tbh I am not sure I like the way we are refactoring the multi-head attention here. I think there is a pretty canonical MHA flow that folks in OSS are used to (see e.g. litgpt's CausalSelfAttention or transformers's LlamaAttention) and this would be diverging from that in a meaningful way. I am OK with such divergence if it makes the code easier to understand, but in this case we are actually adding another layer of abstraction that's not very clear (why do we need a separate module to handle a couple reshapes + SDPA? why do we call it SDPA when it in fact contains a call to nn.functional.scaled_dot_product_attention, which we then call self._attn_fn?)

Anyways this is not to be too harsh on this PR cause I do understand the motivation from the ET perspective. Just so that I am more well-informed here, can you share the optimized SDPA op from ET so I can look at it as a reference? Then maybe we can brainstorm about a less intrusive way we can achieve this.
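
For context in this discussion, a minimal sketch of the "canonical" inline MHA flow the comment refers to (in the spirit of litgpt's CausalSelfAttention and transformers' LlamaAttention; the class below is illustrative, not taken from either library or from torchtune):

```python
import torch.nn.functional as F
from torch import nn


class CanonicalSelfAttention(nn.Module):
    """Illustrative only: reshapes and the SDPA call live inline in the attention module."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x, mask=None):
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to [b, n_h, s, h_d] and call SDPA directly, with no intermediate abstraction.
        q, k, v = (t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=mask is None)
        return self.proj(out.transpose(1, 2).reshape(b, s, d))
```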
