
Fix save adapter weights only #1764

Merged 2 commits into pytorch:main on Oct 8, 2024

Conversation

@ebsmothers (Contributor) commented Oct 8, 2024

Fixes #1713

Summary

We should not be using get_adapter_params on models with activation checkpointing enabled. This is because get_adapter_params iterates over named_modules rather than the state dict. That's a good thing from a memory perspective, but it also means that the state dict hooks that normally fire to strip the activation checkpointing wrapper prefixes (or, previously, FSDP wrapper prefixes) from parameter names never run, so the keys it returns still carry the wrapper prefixes. We could consider adding an option to get_adapter_params to handle this there, since it's a bit tricky, but that also feels hacky.
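
For context, here is a minimal sketch of the mismatch (my own illustration, not code from this PR or from torchtune), using PyTorch's checkpoint_wrapper; the exact prefix string is an implementation detail of the wrapper:

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
model[0] = checkpoint_wrapper(model[0])

# Iterating over modules/parameters exposes the AC wrapper prefix...
print([name for name, _ in model.named_parameters()])
# e.g. ['0._checkpoint_wrapped_module.weight', '0._checkpoint_wrapped_module.bias', ...]

# ...while state_dict() fires the wrapper's hooks, which strip the prefix.
print(list(model.state_dict().keys()))
# e.g. ['0.weight', '0.bias', '1.weight', '1.bias']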

Actually, this is really only an issue in our single-device recipe when save_adapter_weights_only=True; otherwise we filter the state dict on the keys of self.adapter_params, which is correct because (a) self.adapter_params is defined before any AC wrapping, and (b) we are using the state dict, so the relevant hooks have already fired.

But we don't need to call get_adapter_params here at all: we already have references to all the adapter weights, with the correct (non-AC-wrapped) keys, in self.adapter_params as defined here. So we can just use that, regardless of the value of save_adapter_weights_only, which also makes the code a bit cleaner.
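
Roughly, the change amounts to the sketch below (a hypothetical helper with made-up names, not the actual recipe code): take the post-hook state dict, whose keys no longer carry any wrapper prefixes, and filter it on the adapter parameter names captured at setup time.

from typing import Any, Dict, Set
import torch.nn as nn

def filter_adapter_state_dict(
    model: nn.Module, adapter_param_names: Set[str]
) -> Dict[str, Any]:
    # state_dict() runs the AC/FSDP state dict hooks, so the keys here
    # already match the un-wrapped names stored in adapter_param_names.
    full_sd = model.state_dict()
    return {k: v for k, v in full_sd.items() if k in adapter_param_names}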

Test plan

Why didn't we catch this before? All of our existing recipe tests set enable_activation_checkpointing=False, so we missed it. For that reason I've updated the test_recipe_state_on_resume tests to use enable_activation_checkpointing=True. Running those tests on main with the new value of enable_activation_checkpointing:

pytest tests/recipes/test_lora_finetune_single_device.py -m integration_test -k 'test_training_state_on_resume'
...
===== 1 failed, 1 passed, 9 deselected, 1 warning in 7.29s ======


pytest tests/recipes/test_lora_dpo_single_device.py -m integration_test -k 'test_training_state_on_resume'
...
==== 1 failed, 1 passed, 1 deselected, 1 warning in 28.37s =======

If we instead run with the changes on this PR:

pytest tests/recipes/test_lora_finetune_distributed.py -m integration_test
...
========== 7 passed in 127.43s (0:02:07) ==========

pytest tests/recipes/test_lora_dpo_single_device.py -m integration_test
...
========== 3 passed, 2 warnings in 52.46s =======

pytest tests/recipes/test_lora_finetune_single_device.py -m integration_test
...
========== 11 passed, 2 warnings in 271.11s (0:04:31) ===========

Also, the command from #1713 that was failing before now passes:

tune run lora_finetune_single_device --config llama3_2/3B_lora_single_device max_steps_per_epoch=10 \
gradient_accumulation_steps=1 epochs=1 save_adapter_weights_only=True

pytorch-bot bot commented Oct 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1764

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8357be3 with merge base a8a64ec:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Oct 8, 2024
@ebsmothers marked this pull request as ready for review October 8, 2024 04:28
@ebsmothers changed the title from "[WIP] Fix save adapter weights only" to "Fix save adapter weights only" Oct 8, 2024
@RdoubleA (Contributor) left a comment:

solid fix

@ebsmothers merged commit 82f8b77 into pytorch:main Oct 8, 2024
17 checks passed
@ebsmothers deleted the adapter-weight-save-fix branch October 8, 2024 18:12
@joecummings mentioned this pull request Oct 9, 2024
mori360 pushed a commit to mori360/torchtune that referenced this pull request Oct 14, 2024
Successfully merging this pull request may close these issues: Training is stuck at saving checkpoint for Llama3.2