Compile model+loss for LoRA single device recipe #1296
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1296
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 6013bef with merge base c1fabb9. This comment was automatically generated by Dr. CI and updates every 15 minutes.
loss = self._loss_fn(logits, labels)
# free logits otherwise it peaks backward memory
del logits
batch = {k: v.to(self._device) for k, v in batch.items()}
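The pattern in the snippet above can be sketched as a standalone function (the name `loss_step`, the toy model, and the tensor shapes below are illustrative, not the recipe's actual code):

```python
import torch
import torch.nn as nn

def loss_step(model: nn.Module, loss_fn: nn.Module, batch: dict) -> torch.Tensor:
    """Forward pass plus loss, freeing logits before backward to cut peak memory."""
    logits = model(batch["tokens"])
    # Flatten (batch, seq, vocab) logits for the cross-entropy loss.
    loss = loss_fn(logits.view(-1, logits.size(-1)), batch["labels"].view(-1))
    # Free the large logits tensor so it is not kept alive through backward.
    del logits
    return loss

# Toy usage on CPU with a stand-in model.
model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 10))
loss = loss_step(model, nn.CrossEntropyLoss(),
                 {"tokens": torch.randint(0, 10, (2, 5)),
                  "labels": torch.randint(0, 10, (2, 5))})
```

Deleting `logits` before `loss.backward()` matters because the logits tensor is vocabulary-sized and would otherwise stay referenced while backward allocates its own buffers.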
My concern is that the batch may contain keys that are not needed on the GPU. Can we keep the `.to(device)` like it was before? Or did you have some other motivation for moving it?
I tried putting the old `.to(device)` code within `self._loss_step()`, but `torch.compile()` shows some warnings. They should be harmless, but I will double-check.

Another option is to keep the old `.to(device)` outside `self._loss_step()`. Then it is a bit clunkier to pass `tokens`, `labels`, `mask`, and `input_pos` to `self._loss_step()` instead of just `batch` (or we can modify the `batch` object in-place). Let me know which one you prefer and I will change to that.
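The two options can be sketched side by side (the helper names and the `NEEDED` tuple are made up for illustration; the key names mirror the ones mentioned above):

```python
import torch

# Option A: move the whole batch dict; extra keys, if any, also land on device.
def batch_to_device(batch: dict, device: torch.device) -> dict:
    return {k: v.to(device) for k, v in batch.items()}

# Option B: move only the keys the loss step actually consumes.
NEEDED = ("tokens", "labels", "mask", "input_pos")

def batch_to_device_selective(batch: dict, device: torch.device) -> dict:
    return {k: (v.to(device) if k in NEEDED else v) for k, v in batch.items()}
```

Option A keeps the call site simple at the cost of possibly moving unused tensors; Option B is the conservative variant matching the reviewer's concern.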
I think it should be fine to keep it as you did. I don't think our batches output extra terms, so the scenario I raised would be an exception.

Before I approve it, would it make sense to add this to the other recipes, so they are all aligned? For example, should we test it in distributed? We probably want to avoid having different versions of `loss_step` for each recipe. Out of curiosity: did you notice any major difference in compile time?
I think we can add this to the other recipes in separate PRs, since each requires thorough testing to confirm there is no regression and that the expected speedup materializes. Even for this recipe (single-device QLoRA), each mini run with Llama3-8B took me 10+ minutes (most of that time is spent on quantization, due to #1246, and on compilation).
I don't have a multi-GPU setup, so I can't run and test the distributed recipes myself 😅.
Not on my machine. I also checked it myself: I didn't put a timer on it, but this wandb graph suggests they have the same compile time (power usage goes up at the same time, i.e. training starts at the same time).

Another question: I realised that I deleted this line of yours. Since we now compile model+loss together, I don't think this will make a difference.
Sounds good! Gonna let the tests run and approve the PR. Thanks for addressing my questions.
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@           Coverage Diff            @@
##             main    #1296    +/-   ##
==========================================
+ Coverage   68.31%   70.76%   +2.44%
==========================================
  Files         255      258       +3
  Lines       11796    11880      +84
==========================================
+ Hits         8059     8407     +348
+ Misses       3737     3473     -264

☔ View full report in Codecov by Sentry.
Context
What is the purpose of this PR?
Please link to any issues this PR addresses.
Address #1228
To limit the scope of this PR, I will only update the LoRA single device recipe. We can keep #1228 open and discuss how to do model+loss compile for other recipes.
Changelog
What are the changes made in this PR?
Compile model+loss together for LoRA single device recipe. This has shown substantial speed improvements.
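A minimal sketch of the idea, on a toy model (the function name and model are illustrative; the recipe's real setup compiles its LoRA model forward together with `self._loss_fn`):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 10))
loss_fn = nn.CrossEntropyLoss()

def model_and_loss(tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    logits = model(tokens)
    return loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

# Compiling the combined function lets torch.compile capture the forward
# pass and the loss in one graph, instead of compiling the model alone.
compiled_step = torch.compile(model_and_loss)
```

Compiling the model and loss as one unit gives the compiler a chance to fuse the final projection with the loss computation, which is where the speedup and memory savings come from.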
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

- `pre-commit install`
- `pytest tests`
- `pytest tests -m integration_test`
Ran the following command on a 4070 Ti SUPER (16GB VRAM):
15% improvement in tok/s on average, generally lower memory consumption.