
FIX: Disabling adapter works with modules_to_save #736

Conversation

BenjaminBossan (Member) commented:

Resolves #493

For LoRA and IA³, there was a bug: even when using the disable_adapter context, if the module was listed in modules_to_save, the updated weights would be used instead of the original weights. This meant that disable_adapter would not return the same results as the base model without adaptation. This PR fixes the issue and provides a test.

Note: I tried to adjust AdaLoRA too, since it seemed that the same reasoning should apply there. However, AdaLoRA does not appear to support disabling adapters at all; for example, there is no disable_adapter_layers method. Therefore, AdaLoRA was not changed.
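
As a minimal sketch of the behavior being fixed (the model name and inputs are illustrative; peft's disable_adapter context manager and LoraConfig's modules_to_save argument are the real API):

```python
# Sketch: with modules_to_save, the wrapped module's updated weights used to
# be applied even inside the disable_adapter context; after this PR, the
# context yields the same outputs as the unadapted base model.
import torch
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
config = LoraConfig(task_type="SEQ_CLS", modules_to_save=["classifier"])
model = get_peft_model(base, config)

input_ids = torch.tensor([[101, 2023, 102]])  # illustrative token ids
with torch.no_grad(), model.disable_adapter():
    disabled_logits = model(input_ids=input_ids).logits
# Expected after the fix: disabled_logits matches what the base model produced
# before adaptation, because the original classifier weights (not the copied,
# possibly updated ones) are used while adapters are disabled.
```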

HuggingFaceDocBuilderDev commented Jul 20, 2023:

The documentation is not available anymore as the PR was closed or merged.

@@ -303,6 +303,8 @@ def _set_adapter_layers(self, enabled=True):
         for module in self.model.modules():
             if isinstance(module, IA3Layer):
                 module.disable_adapters = False if enabled else True
+            elif isinstance(module, ModulesToSaveWrapper):
+                module.disable_adapter = False if enabled else True
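
The flag only changes behavior if the wrapper's forward honors it. A hypothetical sketch of that routing (the attribute names original_module, modules_to_save, and active_adapter follow peft's ModulesToSaveWrapper, but this is not the verbatim implementation):

```python
# Hypothetical sketch: when the flag is set, fall back to the untouched
# original module instead of the trainable copy kept in modules_to_save.
def forward(self, *args, **kwargs):
    if self.disable_adapter or self.active_adapter not in self.modules_to_save:
        return self.original_module(*args, **kwargs)
    return self.modules_to_save[self.active_adapter](*args, **kwargs)
```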
Contributor:

It would be a bit more intuitive to use the same name, disable_adapters, with an "s", for both IA3Layer and ModulesToSaveWrapper, for homogeneity.

Contributor:

Otherwise the PR looks good to me, as far as I can tell with my limited experience of peft.

BenjaminBossan (Member, Author):

Thanks for taking a look.

Good point about disable_adapter. I went with the singular since only a single adapter is involved, but I can see your argument too. Let's wait and see what the other reviewers prefer.

Contributor:

Yeah, let's have disable_adapters for homogeneity.

BenjaminBossan (Member, Author):

Done
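
After the rename, the loop presumably reads as follows (a sketch of the method body only; the merged code may differ in minor details):

```python
# Sketch of _set_adapter_layers after the rename: both branches now set the
# plural disable_adapters attribute, as the reviewer suggested.
def _set_adapter_layers(self, enabled=True):
    for module in self.model.modules():
        if isinstance(module, IA3Layer):
            module.disable_adapters = False if enabled else True
        elif isinstance(module, ModulesToSaveWrapper):
            module.disable_adapters = False if enabled else True
```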

pacman100 (Contributor) left a comment:

Thank you so much @BenjaminBossan for the fix! 🤗

pacman100 (Contributor) left a comment:

Thank you @BenjaminBossan!

BenjaminBossan merged commit b15c185 into huggingface:main on Jul 24, 2023.
BenjaminBossan deleted the fix-disable-adapter-modules-to-save branch on July 24, 2023 at 11:23.
Development

Successfully merging this pull request may close these issues.

'LoraModel.disable_adapter_layers' not causing the model to use the original modules