FIX: setting requires_grad on adapter layers #905
Conversation
This is an alternative to #900, resolves #899.

Thanks @passaglia for figuring out the underlying issue.

Description

Currently, we don't handle setting `requires_grad` on adapter layers very well. The main issue is that it can be set to `True` on adapter parameters that are not being used, e.g. the `original_module` in `ModulesToSaveWrapper` or inactive adapters in LoRA.

Normally, this is not a big issue, except maybe if we want to correctly count the number of trainable parameters. However, when training with `DistributedDataParallel`, this results in errors, as PyTorch expects all parameters with `requires_grad=True` to participate in the loss computation, but the parameters mentioned above don't. For that reason, training with DDP currently fails when using `modules_to_save` or multiple adapters (see the sketch after this description).

Implementation

This turned out to be more complicated than I initially thought. The logic for setting `requires_grad` is spread all over the place; it was hard to encapsulate and I only partially succeeded. As is, this PR is more complex than the one it tries to supersede, #900, but it is also "more correct".

Tests were added to check that `requires_grad` is set correctly. There are (so far) no tests for whether DDP itself works; they could be added in a multi-GPU setup. I did, however, test an early stage of this PR with DDP, and setting `requires_grad` correctly does fix the DDP error.

DONE/TODO

- [x] ModulesToSaveWrapper
- [x] LoRA
- [ ] IA³
- [ ] AdaLora

Since some tuners are not implemented yet, tests are expected to fail. Check the new tests at the bottom of test_custom.py; those should pass.
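For illustration, here is a minimal, hypothetical sketch of the failure mode described above; it is not part of the PR. It only assumes the public `peft` API (`LoraConfig`, `get_peft_model`); the toy model and the layer names `lin0`/`lin1` are made up, and the printed values reflect the intended behaviour after this fix.

```python
# Illustrative sketch only: a toy model with one LoRA target and one module_to_save.
import torch.nn as nn
from peft import LoraConfig, get_peft_model

base_model = nn.Sequential()
base_model.add_module("lin0", nn.Linear(10, 10))  # gets a LoRA adapter
base_model.add_module("lin1", nn.Linear(10, 2))   # wrapped in ModulesToSaveWrapper

config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(base_model, config)

# Before this PR, the frozen copy inside ModulesToSaveWrapper ("original_module")
# could keep requires_grad=True even though it is not used while an adapter is
# active. DistributedDataParallel expects every parameter with requires_grad=True
# to contribute to the loss, so such unused-but-trainable parameters make DDP
# error during backward. After this PR, only the active copy stays trainable:
for name, param in model.named_parameters():
    if "lin1" in name:
        print(name, param.requires_grad)

# Expected (roughly):
#   ...lin1.original_module.weight           False
#   ...lin1.modules_to_save.default.weight   True
```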
The documentation is not available anymore as the PR was closed or merged.
Thank you @BenjaminBossan for fixing this major bug when using DDP/Multiple Adapters with PEFT. LGTM! 🤗
Thanks a mile @BenjaminBossan !