
[PEFT] Allow PEFT model dict to be loaded #25721

Merged: 13 commits into main from allow_peft_state_dict_loading, Sep 15, 2023

Conversation

@patrickvonplaten (Contributor) commented Aug 24, 2023

In order to allow PEFT to be leveraged in diffusers without breaking changes, we need to allow loading adapters directly from an already-loaded state_dict. The reason is that diffusers currently stores LoRA checkpoints in a format that differs from the PEFT format, so we cannot simply pass the model_id. This PR allows the user to manually pass a loaded PEFT model checkpoint as well as a PEFT configuration, circumventing the need to pass a model id.

In pseudo code, the integration of transformers + PEFT in diffusers should then look as follows in the "load_lora" function of diffusers:

def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
    peft_state_dict, peft_config = convert_to_peft_format(state_dict, ...)  # <- this function will take care of all the remapping necessary for the different formats
    text_encoder.load_adapter(peft_state_dict, peft_config=peft_config)

Note: there might be more changes we have to make to PEFT and to transformers' PEFT integration to be sure that everything works as expected. E.g., I'm not yet sure how to pass network_alphas etc. to PEFT to make sure we get 1-to-1 the same result.
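As a rough illustration of the kind of key remapping a convert_to_peft_format-style helper would have to do, here is a minimal sketch. The key names below ("lora.down.weight" / "lora_A.weight" and so on) are assumptions for illustration only; the real diffusers and PEFT key layouts differ, and the actual conversion logic would live in diffusers:

```python
# Illustrative sketch only: key names are assumed, not the real formats.
def remap_lora_keys(state_dict):
    """Rename assumed diffusers-style LoRA keys to an assumed PEFT-style layout."""
    remapped = {}
    for key, value in state_dict.items():
        # e.g. "...lora.down.weight" -> "...lora_A.weight" (assumed mapping)
        new_key = key.replace("lora.down.weight", "lora_A.weight")
        new_key = new_key.replace("lora.up.weight", "lora_B.weight")
        remapped[new_key] = value
    return remapped


# Placeholder values stand in for torch tensors.
sd = {
    "text_model.q_proj.lora.down.weight": "A",
    "text_model.q_proj.lora.up.weight": "B",
}
print(sorted(remap_lora_keys(sd)))
```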

cc @younesbelkada @sayakpaul @BenjaminBossan @pacman100

@BenjaminBossan (Member) left a comment

Thanks for the PR.

I'm not very knowledgeable about diffusers, so I'll let others comment on the overall solution. My comments are just minor issues, not blockers.

@@ -28,6 +29,10 @@
from accelerate.utils import get_balanced_memory, infer_auto_device_map


if is_torch_available():
Member:

Why is this necessary?

@@ -59,14 +64,15 @@ class PeftAdapterMixin:

def load_adapter(
self,
-        peft_model_id: str,
+        peft_model_id: Union[str, Dict[str, "torch.Tensor"]],
Member:
Should we choose a different name, now that a state dict can be passed? Alternatively, we could add another (optional) argument to pass the state dict, make peft_model_id optional, and do a check that exactly one of the two should be passed.
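A minimal sketch of that mutual-exclusion check, assuming hypothetical argument names (the real signature and error message are decided later in the thread):

```python
# Sketch of the suggested "exactly one of the two" validation; names are illustrative.
def resolve_adapter_source(peft_model_id=None, state_dict=None):
    """Require exactly one of `peft_model_id` or `state_dict` to be passed."""
    if (peft_model_id is None) == (state_dict is None):
        raise ValueError(
            "You must pass exactly one of `peft_model_id` or `state_dict`."
        )
    return "hub" if peft_model_id is not None else "state_dict"
```

Either source then flows into the same loading path, so downstream code does not need to care where the weights came from.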

Contributor Author:
Yes, makes sense to me! peft_model_id_or_state_dict sounds great. I was just wondering about a breaking change here, but this function is not yet in a release, so it should probably be fine to change, no?

@@ -75,7 +81,7 @@ def load_adapter(
Requires peft as a backend to load the adapter weights.

Args:
-            peft_model_id (`str`):
+            peft_model_id (`str` or dictionary of `torch.Tensor`):
Member:
The description below should also be adjusted.

adapter_name: Optional[str] = None,
revision: Optional[str] = None,
token: Optional[str] = None,
device_map: Optional[str] = "auto",
max_memory: Optional[str] = None,
offload_folder: Optional[str] = None,
offload_index: Optional[int] = None,
peft_config: Dict[str, Any] = None,
Member:
Suggested change
-        peft_config: Dict[str, Any] = None,
+        peft_config: Optional[Dict[str, Any]] = None,

@HuggingFaceDocBuilderDev commented Aug 24, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada (Contributor) left a comment

Looks great! I also agree with @BenjaminBossan's comments. I think we should maybe add an extra check and raise a proper error. For the naming, what about peft_model_id_or_state_dict (maybe that's too long)?

    raise ValueError(
        f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
        "adapter model."
    )
if peft_config is None:
Contributor:

Suggested change
-    if peft_config is None:
+    if peft_config is None and isinstance(peft_model_id, str):

and add a check below that raises an error if peft_model_id is not a state dict
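Combined, the suggested guard might look like this sketch. The function name, the dict-based config stand-in, and the error message are illustrative, not the actual implementation:

```python
# Sketch of the suggested guard: only resolve a config from the Hub when a
# model id string was passed; a state dict requires an explicit peft_config.
def resolve_peft_config(peft_model_id, peft_config=None):
    if peft_config is None:
        if isinstance(peft_model_id, str):
            # Stand-in for the real Hub config lookup.
            peft_config = {"loaded_from": peft_model_id}
        elif isinstance(peft_model_id, dict):
            raise ValueError(
                "When passing a state dict, `peft_config` must also be provided."
            )
    return peft_config
```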

@pacman100 (Contributor) left a comment

Thank you @patrickvonplaten for adding the support for passing state dict and config, LGTM!

@younesbelkada (Contributor) left a comment

Thanks again !

@younesbelkada younesbelkada changed the title [WIP][PEFT] Allow PEFT model dict to be loaded [PEFT] Allow PEFT model dict to be loaded Sep 15, 2023
@ArthurZucker (Collaborator) left a comment

I have one question regarding peft_model_id being a dictionary; otherwise this looks good.

Comment on lines 86 to 88
peft_model_id (`str` or dictionary of `torch.Tensor`):
The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
and adapter weights.
Collaborator:

The doc needs to be updated if this can be a state_dict. The name of the argument is not intuitive, but I guess we need to keep it for backward compatibility? Otherwise I would rather have a new arg.

Contributor:

I think for BC it's fine as the feature is quite new

Collaborator:

Not really if it was part of the release 😅

@ArthurZucker (Collaborator) left a comment

LGTM thanks for iterating and adding a test!

@younesbelkada younesbelkada merged commit 0a55d9f into main Sep 15, 2023
21 checks passed
@younesbelkada younesbelkada deleted the allow_peft_state_dict_loading branch September 15, 2023 16:22
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
* Allow PEFT model dict to be loaded

* make style

* make style

* Apply suggestions from code review

* address comments

* fixup

* final change

* added tests

* fix test

* better logic for handling if adapter has been loaded

* Update tests/peft_integration/test_peft_integration.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023

EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023