DoRA fixes #2139

Merged
merged 7 commits into from Dec 11, 2024

Conversation

@ebsmothers (Contributor) commented Dec 10, 2024

Co-authored-by: @mirceamironenco

Fixes #2129

This PR fixes our broken DoRA distributed training. It also makes some quality-of-life improvements to how we initialize DoRALinear and (in a small way) LoRALinear.

We do the following:

  1. Define a to_empty method on both DoRALinear and LoRALinear. This consolidates the move of all the potentially-still-on-meta-device params onto a real device in a single place (lora_a.weight and lora_b.weight for LoRA, and additionally magnitude for DoRA) and matches the contract of nn.Module. For DoRALinear's magnitude we create a new parameter with torch.empty_like and manually set requires_grad, then use the torch.utils.swap_tensors API to ensure self.magnitude winds up on a real device.

  2. For DoRA, we modify the initialize_dora_magnitude method to (a) check that the base weight and LoRA weights are not on meta device, (b) calculate ||W + (alpha/rank)BA||, and (c) copy this value into self.magnitude. Since this should always be called after to_empty, we now know that self.magnitude is already on a real device by the time we copy into it (see the sketch just below this list).
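
To make (1) and (2) concrete, here is a minimal, self-contained sketch of the approach described above. It is illustrative only, not the exact torchtune implementation: the class name DoRALinearSketch, the frozen base weight stored on self.weight, and the scaling attribute are assumptions for the example.

import torch
import torch.nn as nn


class DoRALinearSketch(nn.Module):
    """Illustrative DoRA linear layer, showing only the pieces relevant to this PR."""

    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float):
        super().__init__()
        self.scaling = alpha / rank
        # Frozen base weight; in practice this is loaded from the checkpoint.
        self.weight = nn.Parameter(torch.empty(out_dim, in_dim), requires_grad=False)
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        # One magnitude entry per output row; may initially live on meta device.
        self.magnitude = nn.Parameter(torch.empty(out_dim))

    def to_empty(self, *, device, recurse: bool = True):
        # (1) Move the potentially-still-on-meta params onto a real device in one
        # place, matching the nn.Module.to_empty contract.
        self.lora_a.to_empty(device=device, recurse=recurse)
        self.lora_b.to_empty(device=device, recurse=recurse)
        magnitude = nn.Parameter(
            torch.empty_like(self.magnitude, device=device),
            requires_grad=self.magnitude.requires_grad,
        )
        torch.utils.swap_tensors(self.magnitude, magnitude)
        return self

    @torch.no_grad()
    def initialize_dora_magnitude(self):
        # (2) Always called after to_empty, so self.magnitude is on a real device here.
        for name, param in [
            ("weight", self.weight),
            ("lora_a.weight", self.lora_a.weight),
            ("lora_b.weight", self.lora_b.weight),
        ]:
            if param.is_meta:
                raise RuntimeError(f"{name} is still on meta device; call to_empty first")
        lora_weight = self.lora_b.weight @ self.lora_a.weight
        # ||W + (alpha / rank) * B A||: one L2 norm per output row.
        weight_norm = torch.linalg.norm(self.weight + self.scaling * lora_weight, dim=1)
        self.magnitude.copy_(weight_norm)

The intended ordering, per the description above, is: to_empty first, then load the real weights, then initialize_dora_magnitude.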

BC note: From the perspective of someone importing from our library into their own custom recipe, the changes to LoRALinear's contract with respect to to_empty are backwards compatible: calling lora_linear.to_empty(device) (new version) is identical to calling lora_linear.lora_a.to_empty(device); lora_linear.lora_b.to_empty(device) (old version). For DoRALinear, the recipe will be broken without updating to the new version, but the recipe was broken anyway, so I claim this is OK.
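
Written out, the two call patterns from the BC note look like this (device here stands for whatever real device the recipe targets; shown in keyword form):

# Old recipe code (before this PR): move each LoRA submodule individually.
lora_linear.lora_a.to_empty(device=device)
lora_linear.lora_b.to_empty(device=device)

# New recipe code (with this PR): a single call on the wrapping module.
lora_linear.to_empty(device=device)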

Other changes:

  • Deprecate (but don't yet delete) load_dora_magnitudes, which is redundant.
  • Add unit tests for proper calculation of the DoRA magnitude on single-device and distributed with meta-device init (the core check is sketched just below).
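
The core single-device assertion such a test needs can be sketched as follows. This is illustrative only; the helper name check_dora_magnitude and the module.weight / module.scaling attributes are assumptions, not the actual test code in this PR.

import torch


def check_dora_magnitude(module) -> None:
    # After to_empty, weight loading, and initialize_dora_magnitude, the magnitude
    # should equal the row-wise L2 norm of W + (alpha / rank) * B A.
    with torch.no_grad():
        lora_weight = module.lora_b.weight @ module.lora_a.weight
        expected = torch.linalg.norm(module.weight + module.scaling * lora_weight, dim=1)
    torch.testing.assert_close(module.magnitude.detach(), expected)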

Test plan

Unit tests

pytest tests/torchtune/modules/peft/test_dora.py
...
======== 12 passed in 11.73s ========

E2E tests

Single-device recipe

Compare the following three cases:

  1. This PR
  2. Resume intermediate checkpoint on this PR
  3. Baseline (run on main)

Loss curves for all three are the same:

[Screenshot: loss curves for the three runs]

Command for (1):

tune run lora_finetune_single_device --config llama3/8B_dora_single_device max_steps_per_epoch=100 epochs=2 \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=test-dora-fixes \
metric_logger.name=dora-single-new

Commands for (2):

# First delete final checkpoint to automatically resume from intermediate checkpoint
rm -r /tmp/torchtune/llama3_8B/dora_single_device/epoch_1

tune run lora_finetune_single_device --config llama3/8B_dora_single_device max_steps_per_epoch=100 epochs=2 \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=test-dora-fixes \
metric_logger.name=dora-single-new-resumed resume_from_checkpoint=True

Distributed recipe

  1. This PR
  2. Resume intermediate checkpoint on this PR

(Note that case (3) from the single-device recipe is not applicable here, since distributed DoRA is broken on main.) Loss curves are the same after resuming from checkpoint:

[Screenshot: loss curves for the two runs]

Command for (1):

tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_dora max_steps_per_epoch=100 epochs=2 \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=test-dora-fixes \
metric_logger.name=dora-distributed-new

Commands for (2):

# First delete final checkpoint to automatically resume from intermediate checkpoint
rm -r /tmp/torchtune/llama3_8B/dora/epoch_1

tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_dora max_steps_per_epoch=100 epochs=2 \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=test-dora-fixes \
metric_logger.name=dora-distributed-new-resume resume_from_checkpoint=True

pytorch-bot (bot) commented Dec 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2139

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 031d89b with merge base 06a8379:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Dec 10, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)

magnitude = nn.Parameter(
    torch.empty_like(self.magnitude, device=self.lora_a.weight.device),
    requires_grad=self.magnitude.requires_grad,

Collaborator:

When would this be False?

Contributor Author (ebsmothers):

Offhand I can’t think of a case but imo better to just be explicit here

    requires_grad=self.magnitude.requires_grad,
)
magnitude = self._get_weight_norm(base_weight, lora_weight)
torch.utils.swap_tensors(self.magnitude, magnitude)

Collaborator:

not 100% sure here, but will the original torch.empty(out_dim) still be floating around after this? This probably isn't an issue?

Contributor Author (ebsmothers):

Yeah I was wondering about this too, need to check myself. Offhand not sure about the implications for refcounting with this utility. But magnitude is just embed_dim number of elements so even if still hanging around it’s pretty small
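
For reference, here is a tiny standalone illustration of torch.utils.swap_tensors semantics (a generic example, not code from this PR):

import torch

a = torch.zeros(3)
b = torch.ones(3)

# swap_tensors exchanges the contents of the two tensor objects in place,
# so each variable now refers to the other's original data.
torch.utils.swap_tensors(a, b)

print(a)  # tensor([1., 1., 1.])
print(b)  # tensor([0., 0., 0.])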

@ebsmothers changed the title from [WIP] DoRA fixes to DoRA fixes on Dec 11, 2024
@codecov-commenter commented Dec 11, 2024

Codecov Report

Attention: Patch coverage is 88.18182% with 13 lines in your changes missing coverage. Please review.

Project coverage is 67.75%. Comparing base (06a8379) to head (c7e937f).
Report is 6 commits behind head on main.

Files with missing lines                          Patch %   Lines
recipes/lora_dpo_distributed.py                   0.00%     4 Missing ⚠️
recipes/knowledge_distillation_single_device.py   0.00%     3 Missing ⚠️
torchtune/modules/peft/lora.py                    50.00%    2 Missing ⚠️
recipes/knowledge_distillation_distributed.py     0.00%     1 Missing ⚠️
recipes/lora_finetune_distributed.py              0.00%     1 Missing ⚠️
recipes/qat_lora_finetune_distributed.py          0.00%     1 Missing ⚠️
tests/recipes/test_lora_finetune_distributed.py   0.00%     1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            main    #2139       +/-   ##
==========================================
+ Coverage   9.33%   67.75%   +58.41%     
==========================================
  Files        289      334       +45     
  Lines      16956    19281     +2325     
==========================================
+ Hits        1583    13064    +11481     
+ Misses     15373     6217     -9156     

☔ View full report in Codecov by Sentry.

@ebsmothers marked this pull request as draft on December 11, 2024 00:37
@ebsmothers marked this pull request as ready for review on December 11, 2024 01:07
@ebsmothers merged commit 9cfa288 into pytorch:main on Dec 11, 2024
17 checks passed
@SalmanMohammadi mentioned this pull request on Dec 16, 2024
17 tasks
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 18, 2024
* Llama 3.3 70B (pytorch#2124)

* Llama 3.3 readme updates (pytorch#2125)

* update configs (pytorch#2107)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* Reduce logging output for distributed KD (pytorch#2120)

* Support Early Exit Loss and/or Layer Dropout (pytorch#1076)

Co-authored-by: ebsmothers <ebs@meta.com>

* Update checkpointing directory (pytorch#2074)

Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: vancoyendall <vancoykendall@gmail.com>

* pass correct arg (pytorch#2127)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* update configs (pytorch#2128)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* fix qat_lora_test (pytorch#2131)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* guard ckpt imports (pytorch#2133)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* [bug fix] add parents=True (pytorch#2136)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* [bug fix] re-add model (pytorch#2135)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* Update save sizes into GiB (pytorch#2143)

* [bug fix] remove config download when source is kaggle (pytorch#2144)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* [fix] remove "with_suffix" (pytorch#2146)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* DoRA fixes (pytorch#2139)



Co-authored-by: Mircea Mironenco <5738815+mirceamironenco@users.noreply.github.com>

* [Fix] Llama 3.2 Vision decoder_trainable flag fixed (pytorch#2150)

* Small readme, config updates (pytorch#2157)

* Using `FormattedCheckpointFiles` in configs (pytorch#2147)

* Move ``get_world_size_and_rank`` to utils (pytorch#2155)

* Faster intermediate checkpoints with DCP async save in TorchTune (pytorch#2006)

Co-authored-by: Saurabh Mishra <msaurabh@fb.com>

* torchdata integration - multi-dataset and streaming support (pytorch#1929)

* Allow higher version of lm-eval (pytorch#2165)

* Using `FormattedCheckpointFiles` in configs... round 2 (pytorch#2167)

* [EZ] Fix set_torch_num_threads in multi-node. (pytorch#2164)

---------

Co-authored-by: Philip Bontrager <pbontrager@gmail.com>
Co-authored-by: ebsmothers <ebs@meta.com>
Co-authored-by: Felipe Mello <fmellomascarenhas@gmail.com>
Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: Joe Cummings <jrcummings27@gmail.com>
Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>
Co-authored-by: vancoyendall <vancoykendall@gmail.com>
Co-authored-by: Mircea Mironenco <5738815+mirceamironenco@users.noreply.github.com>
Co-authored-by: salman <salman.mohammadi@outlook.com>
Co-authored-by: Saurabh Mishra <msaurabh@meta.com>
Co-authored-by: Saurabh Mishra <msaurabh@fb.com>
Co-authored-by: Andrew Ho <andrew.kenneth.ho@gmail.com>
Co-authored-by: Eugen Hotaj <eugen_hotaj_91@hotmail.com>
Labels
CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Distributed DoRA training is broken
5 participants