DoRA fixes #2139
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2139
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 031d89b with merge base 06a8379.
torchtune/modules/peft/dora.py (Outdated)

```python
magnitude = nn.Parameter(
    torch.empty_like(self.magnitude, device=self.lora_a.weight.device),
    requires_grad=self.magnitude.requires_grad,
)
```
When would this be `False`?
Offhand I can’t think of a case but imo better to just be explicit here
torchtune/modules/peft/dora.py
Outdated
requires_grad=self.magnitude.requires_grad, | ||
) | ||
magnitude = self._get_weight_norm(base_weight, lora_weight) | ||
torch.utils.swap_tensors(self.magnitude, magnitude) |
not 100% sure here, but will the original `torch.empty(out_dim)` still be floating around after this? This probably isn't an issue?
Yeah, I was wondering about this too; I need to check myself. Offhand I'm not sure about the implications for refcounting with this utility. But `magnitude` is just `embed_dim` elements, so even if it's still hanging around it's pretty small.
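For context on the refcounting question, here is a minimal, self-contained sketch (not torchtune code; the shape and devices are made up) of what `torch.utils.swap_tensors` does to the two tensors involved:

```python
import torch
import torch.nn as nn

# Stand-in for self.magnitude living on the meta device before to_empty().
magnitude = nn.Parameter(torch.empty(8, device="meta"))

# Real-device replacement, mirroring the torch.empty_like(...) call in the diff above.
new_magnitude = nn.Parameter(
    torch.empty_like(magnitude, device="cpu"),
    requires_grad=magnitude.requires_grad,
)

# Swap contents in place: `magnitude` keeps its Python identity (so anything that
# registered it, e.g. a module's parameter dict, still sees the same object) but
# now points at the CPU storage.
torch.utils.swap_tensors(magnitude, new_magnitude)

print(magnitude.device)      # cpu
print(new_magnitude.device)  # meta
# The old empty storage ends up referenced only by the local `new_magnitude`;
# once that goes out of scope it can be freed, so nothing sizable should linger.
```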
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #2139       +/-   ##
==========================================
+ Coverage    9.33%   67.75%   +58.41%
==========================================
  Files         289      334       +45
  Lines       16956    19281     +2325
==========================================
+ Hits         1583    13064    +11481
+ Misses      15373     6217     -9156
```

☔ View full report in Codecov by Sentry.
Co-authored by: @mirceamironenco
Fixes #2129
This PR fixes our broken DoRA distributed training. It also makes some quality-of-life improvements to how we initialize `DoRALinear` and (in a small way) `LoRALinear`. We do the following:

- Define a `to_empty` method on both `DoRALinear` and `LoRALinear`. This consolidates the move of all the potentially-still-on-meta-device params onto a real device in a single place (`lora_a.weight` and `lora_b.weight` for LoRA, and additionally `magnitude` for DoRA) and matches the contract of `nn.Module`. For `DoRALinear`'s magnitude we create a new parameter with `torch.empty_like` and manually set `requires_grad`, then use the `torch.utils.swap_tensors` API to ensure `self.magnitude` winds up on a real device (see the sketch after this list).
- For DoRA, we modify the `initialize_dora_magnitude` method to (a) check that the base weight and LoRA weights are not on meta device, (b) calculate ||W + (alpha/rank)BA||, and (c) copy this value into `self.magnitude`. Since this should always be called after `to_empty`, we know that `self.magnitude` is no longer on meta device at that point.
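As a rough illustration of the two bullets above, here is a minimal self-contained sketch. It is not the torchtune implementation: the class is stripped down to the relevant attributes, the base weight is passed in explicitly, and the row-wise L2 norm is an assumption about how ||W + (alpha/rank)BA|| is taken.

```python
import torch
import torch.nn as nn


class DoRALinearSketch(nn.Module):
    """Reduced stand-in for DoRALinear: only the pieces relevant to meta-device init."""

    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float):
        super().__init__()
        self.scaling = alpha / rank
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        self.magnitude = nn.Parameter(torch.empty(out_dim))

    def to_empty(self, *, device, recurse: bool = True):
        # Move the LoRA params off the meta device in one place, matching nn.Module.to_empty.
        self.lora_a.to_empty(device=device, recurse=recurse)
        self.lora_b.to_empty(device=device, recurse=recurse)
        # magnitude: build a real-device parameter, set requires_grad explicitly,
        # then swap it into place so self.magnitude keeps its identity.
        magnitude = nn.Parameter(
            torch.empty_like(self.magnitude, device=device),
            requires_grad=self.magnitude.requires_grad,
        )
        torch.utils.swap_tensors(self.magnitude, magnitude)
        return self

    @torch.no_grad()
    def initialize_dora_magnitude(self, base_weight: torch.Tensor):
        # (a) everything must already be off the meta device (i.e. to_empty was called)
        for t in (base_weight, self.lora_a.weight, self.lora_b.weight, self.magnitude):
            if t.is_meta:
                raise RuntimeError("Call to_empty() before initializing the DoRA magnitude")
        # (b) ||W + (alpha / rank) * B @ A||, taken row-wise here (an assumption)
        lora_weight = self.lora_b.weight @ self.lora_a.weight
        weight_norm = torch.linalg.norm(base_weight + self.scaling * lora_weight, dim=1)
        # (c) copy the result into the existing parameter
        self.magnitude.copy_(weight_norm)
```

The design point of using `torch.utils.swap_tensors` is that `self.magnitude` keeps its Python identity, so anything already holding a reference to the registered parameter continues to see the same object.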
BC note: From the perspective of someone importing from our library into their own custom recipe, the changes to `LoRALinear`'s contract with respect to `to_empty` are backwards compatible: calling `lora_linear.to_empty(device)` (new version) vs `lora_linear.lora_a.to_empty(device); lora_linear.lora_b.to_empty(device)` (old version) gives identical behavior. For `DoRALinear`, the recipe will be broken without the update to the new version... but the recipe was broken anyway, so I claim this is OK.
Other changes:

- Remove `load_dora_magnitudes`, which is redundant.
, which is redundant.Test plan
Unit tests
E2E tests
Single-device recipe
Compare the following three cases:
Loss curves for all three are the same:
Command for (1):
Commands for (2):
Distributed recipe
(Note that (3) from the single-device recipe is not applicable, as this is broken on main.) Loss curves are the same after resuming from checkpoint:
Command for (1):
Commands for (2):