Remove lr_scheduler requirement in lora_dpo_single_device #1991
base: main
Conversation
Thanks for making this update; I had a few comments on the implementation.
@@ -243,7 +243,7 @@ def setup(self, cfg: DictConfig) -> None:
         # Learning rate scheduler can only be set up after number of steps
         # has been computed
         self._lr_scheduler = self._setup_lr_scheduler(
-            cfg_lr_scheduler=cfg.lr_scheduler,
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
It may be less indirection to check directly whether lr_scheduler exists and set the attribute to None here, instead of calling the setup method only to have it return None:
cfg_lr_scheduler = cfg.get("lr_scheduler", None)
self._lr_scheduler = self._setup_lr_scheduler(...) if cfg_lr_scheduler is not None else None
That makes sense to me, Rafi, but on second thought I like the idea of handling everything inside of `_setup_lr_scheduler`, including the log_info. What do you think?
Sure, either works, no strong opinions
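As a concrete reference for the option discussed above, here is a minimal, hypothetical sketch of a `_setup_lr_scheduler` that handles the missing config (and the log message) itself. The `config.instantiate` call, logger setup, `self._optimizer`, and the keyword arguments are assumptions modeled on typical torchtune recipe code, not the exact implementation in this PR:

```python
# Hypothetical sketch, not the actual recipe code: handle a missing
# lr_scheduler config entirely inside the setup helper, including logging.
from typing import Optional

from omegaconf import DictConfig
from torch.optim.lr_scheduler import LRScheduler

from torchtune import config, utils

log = utils.get_logger("DEBUG")


class LoRADPORecipeSketch:
    # Stand-in for the recipe class; assumes self._optimizer already exists.
    def _setup_lr_scheduler(
        self,
        cfg_lr_scheduler: Optional[DictConfig],
        num_training_steps: int,
        last_epoch: int,
    ) -> Optional[LRScheduler]:
        # No lr_scheduler in the config: skip creation and train at a
        # constant learning rate.
        if cfg_lr_scheduler is None:
            log.info("No lr_scheduler configured; using a constant learning rate.")
            return None
        # Otherwise instantiate the scheduler from its config node, the same
        # way the recipe builds its other components.
        lr_scheduler = config.instantiate(
            cfg_lr_scheduler,
            self._optimizer,
            num_training_steps=num_training_steps,
            last_epoch=last_epoch,
        )
        log.info("Learning rate scheduler is initialized.")
        return lr_scheduler
```

With this shape, the caller in `setup()` stays a single line, e.g. `self._lr_scheduler = self._setup_lr_scheduler(cfg_lr_scheduler=cfg.get("lr_scheduler", None), ...)`, matching the diff above.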
@@ -191,3 +191,49 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
         llama2_model.load_state_dict(sd)
         merged_ckpt_out = llama2_model(inputs)
         torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)
+
+    @pytest.mark.integration_test
+    def test_lr_scheduler_optional(self, tmpdir, monkeypatch):
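For context on what such a test involves: torchtune's recipe integration tests are generally driven by building a `tune run` command line and executing the CLI entry point in-process. A rough, hypothetical sketch of that pattern (the config path, overrides, and `TUNE_PATH` import are assumptions for illustration, not the actual test added in this PR):

```python
# Hypothetical sketch of the recipe-integration-test pattern; the config
# path, overrides, and TUNE_PATH import are assumptions for illustration.
import runpy
import sys

import pytest

from tests.common import TUNE_PATH  # assumed path to the `tune` CLI entry point


@pytest.mark.integration_test
def test_lr_scheduler_optional(tmpdir, monkeypatch):
    # A config that omits the lr_scheduler section would be written to
    # tmpdir (or an existing config reused with overrides) before this point.
    cmd = f"""
    tune run lora_dpo_single_device
        --config {tmpdir}/config_without_lr_scheduler.yaml
        output_dir={tmpdir}
        epochs=1
    """.split()
    monkeypatch.setattr(sys, "argv", cmd)
    # The CLI exits via sys.exit, so the run is wrapped in SystemExit.
    with pytest.raises(SystemExit):
        runpy.run_path(TUNE_PATH, run_name="__main__")
```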
This is a great test, but it seems expensive just to check that the lr_scheduler is not created. Maybe you could just call the recipe's setup() method with a toy model and check that self._lr_scheduler is None?
I initially wanted to go with this approach, but I do not think we are allowed to import anything in recipes (i.e. LoRADPORecipeSingleDevice) from tests. Specifically, running `from recipes.lora_dpo_single_device import LoRADPORecipeSingleDevice` raises this error:
Lines 19 to 23 in 912af64:

raise ModuleNotFoundError(
    "The torchtune recipes directory isn't a package and you should not import anything from here. "
    "Refer to our docs for detailed instructions on how to use recipes: "
    "https://pytorch.org/torchtune/main/deep_dives/recipe_deepdive.html"
)
I am OK with not adding a test for it. Any thoughts, @ebsmothers?
yeah, let's just remove it
Thanks for the PR, @thomasjpfan! I left a couple of comments. After Rafi and Evan reply, I am good with merging it. PS: after the changes for lora_dpo_single_device are done, would you be willing to add it to the other recipes too? It's OK if the answer is no. We can do it as a follow-up :)
I'll be happy to add it to the other recipes as a follow-up. I started off with one recipe to see how much test coverage is required.
@thomasjpfan let's go ahead and remove the test and this is good to go!
Just before we approve it, do you mind running it with the wandb logger (or any other logger) and capturing the learning rate with/without the scheduler, as a sanity check? You can append this to your CLI command:
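As a hedged illustration only, the kind of override meant here would be a metric-logger config override on the `tune run` command, e.g. something like `metric_logger._component_=torchtune.training.metric_logging.WandBLogger metric_logger.project=<your_project>`; the component path and parameter are assumptions based on torchtune's metric logging module, not quoted from the original comment.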
Context
What is the purpose of this PR?
Towards #1751
For testing, I initially wanted to `monkeypatch` the `recipes` module, but it was not designed to be imported from tests (see `torchtune/recipes/__init__.py`, lines 7 to 8 in 912af64). The integration test instead runs `tune run` for one epoch.

Changelog
What are the changes made in this PR?
- Remove the `lr_scheduler` requirement in `lora_dpo_single_device`

Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- `pre-commit install`
- `pytest tests`
- `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.