Compile model+loss for LoRA single device recipe #1296

Merged: 6 commits into pytorch:main from compile_full on Aug 10, 2024

Conversation

@gau-nernst (Contributor) commented Aug 9, 2024:

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Address #1228

To limit the scope of this PR, I will only update the LoRA single device recipe. We can keep #1228 open and discuss how to do model+loss compile for other recipes.

Changelog

What are the changes made in this PR?

Compile the model and loss together for the LoRA single-device recipe. This shows a substantial speed improvement (~15% higher tok/s in the test below).
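For illustration only, here is a minimal sketch of the idea with a toy stand-in model (the standalone `loss_step` below mirrors the recipe's `self._loss_step()` but is not the recipe's actual code): instead of compiling just the model, the whole forward + loss step is wrapped in `torch.compile`, so the loss computation lands in the same graph as the forward pass.

```python
# Minimal sketch (toy stand-in model, not the recipe's exact code):
# compile the forward pass and the loss together instead of only the model.
import torch
import torch.nn as nn

model = nn.Linear(128, 32000)   # stand-in for the (Q)LoRA model
loss_fn = nn.CrossEntropyLoss()

def loss_step(batch: dict[str, torch.Tensor]) -> torch.Tensor:
    logits = model(batch["tokens"])
    loss = loss_fn(logits, batch["labels"])
    # free logits otherwise it peaks backward memory
    del logits
    return loss

# Previously only the model would be compiled; compiling the whole step lets
# torch.compile optimize across the model forward and the loss computation.
compiled_loss_step = torch.compile(loss_step)

batch = {"tokens": torch.randn(4, 128), "labels": torch.randint(0, 32000, (4,))}
compiled_loss_step(batch).backward()
```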

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Ran the following command on a 4070 Ti SUPER (16 GB VRAM):

tune run lora_finetune_single_device --config llama3/8B_qlora_single_device compile=True metric_logger._component_=torchtune.utils.metric_logging.WandBLogger batch_size=4 gradient_accumulation_steps=1 log_peak_memory_stats=True

[WandB screenshot: tok/s and peak memory comparison]

~15% improvement in tok/s on average, with generally lower memory consumption.


pytorch-bot bot commented Aug 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1296


✅ No Failures

As of commit 6013bef with merge base c1fabb9:
💚 Looks good so far! There are no failures yet. 💚


@facebook-github-bot added the CLA Signed label Aug 9, 2024
Review thread on recipes/lora_finetune_single_device.py:

```python
loss = self._loss_fn(logits, labels)
# free logits otherwise it peaks backward memory
del logits
batch = {k: v.to(self._device) for k, v in batch.items()}
```
@felipemello1 (Contributor) commented:

My concern is that the batch may contain keys that don't need to be on the GPU. Can we keep the .to(device) as it was before? Or did you have some other motivation for moving it?

@gau-nernst (Contributor, Author) replied:

I tried putting the old .to(device) code inside self._loss_step(), but torch.compile() emits some warnings. They should be harmless, but I will double-check.

Another option is to keep the old .to(device) outside self._loss_step(). It is then a bit clunkier to pass tokens, labels, mask, input_pos to self._loss_step() instead of just batch (or we could modify the batch object in-place). Let me know which one you prefer, and I will change to that.
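(For illustration only, a rough sketch of the two options using toy stand-ins; the function names here are hypothetical, and the real recipe methods take self and also handle mask and input_pos.)

```python
# Rough sketch of the two options (toy stand-ins, not the recipe's code).
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(128, 32000).to(device)   # stand-in for the LoRA model
loss_fn = nn.CrossEntropyLoss()

# Option A (what this PR does): move the batch to the device inside the loss step,
# so the compiled step receives the raw batch dict straight from the dataloader.
def loss_step_device_inside(batch: dict[str, torch.Tensor]) -> torch.Tensor:
    batch = {k: v.to(device) for k, v in batch.items()}
    logits = model(batch["tokens"])
    loss = loss_fn(logits, batch["labels"])
    del logits  # free logits otherwise it peaks backward memory
    return loss

# Option B: keep .to(device) in the training loop and pass the tensors explicitly
# (in the real recipe that would be tokens, labels, mask, input_pos).
def loss_step_device_outside(tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    logits = model(tokens)
    loss = loss_fn(logits, labels)
    del logits
    return loss
```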

@felipemello1 (Contributor) replied Aug 9, 2024:

I think it's fine to keep it as you did. I don't think our batches contain extra keys, so the scenario I raised would be an exception.

@gau-nernst marked this pull request as ready for review August 9, 2024 17:21
@felipemello1 (Contributor) commented:

Before I approve it, would it make sense to add this to the other recipes so they are all aligned? For example, should we test it in the distributed recipe? We probably want to avoid having different versions of the loss step in each recipe.

Out of curiosity: did you notice any major difference in compile time?

@gau-nernst (Contributor, Author) replied:

would it make sense to add it to the other recipes, so they are all aligned?

I think we can add this to the other recipes in separate PRs, since each requires careful testing to confirm there is no regression and the expected speedup. For example, even for this recipe (single-device QLoRA), each mini run with Llama3-8B took me 10+ minutes (most of the time is spent on quantization, due to #1246, and on compilation).

should we test it in distributed

I don't have a multi-GPU setup, so I can't run and test it myself 😅.

We probably want to avoid having different versions of loss_step for each recipe.

I think the point is that self._loss_step() will be different for each recipe, so each recipe can have its own custom logic. But at least among the recipes that train with an autoregressive language-modeling loss (not DPO/PPO), self._loss_step() should be the same.

Did you notice any major difference in compiling time?

Not on my machine. I didn't add a timer, but this WandB graph suggests the compile time is the same (power usage ramps up at the same time, so training starts at the same time).

[WandB screenshot: GPU power usage over time for both runs]

Another question: I realised that I deleted this line of yours.

Since we now compile model+loss together, I don't think it will make a difference for compile=True anymore, but it probably still has an impact in eager mode (compile=False). I will add it back. Is that ok with you?

@felipemello1 (Contributor) commented:

Sounds good! I'm going to let the tests run and then approve the PR. Thanks for addressing my questions.

@codecov-commenter commented:

Codecov Report

Attention: Patch coverage is 0% with 16 lines in your changes missing coverage. Please review.

Project coverage is 70.76%. Comparing base (f6ddfcc) to head (6013bef).
Report is 2 commits behind head on main.

Files                                     Patch %   Lines
recipes/lora_finetune_single_device.py    0.00%     16 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1296      +/-   ##
==========================================
+ Coverage   68.31%   70.76%   +2.44%     
==========================================
  Files         255      258       +3     
  Lines       11796    11880      +84     
==========================================
+ Hits         8059     8407     +348     
+ Misses       3737     3473     -264     

☔ View full report in Codecov by Sentry.

@felipemello1 merged commit 00bbd53 into pytorch:main Aug 10, 2024
20 checks passed
@gau-nernst deleted the compile_full branch August 10, 2024 13:39