[PEFT] Allow PEFT model dict to be loaded #25721
Conversation
Force-pushed from c9e7fb8 to 032bde0.
Thanks for the PR.
I'm not very knowledgeable about diffusers, so will let others comment on the overall solution. My comments are just minor issues, not blockers.
@@ -28,6 +29,10 @@
 from accelerate.utils import get_balanced_memory, infer_auto_device_map

+if is_torch_available():
Why is this necessary?
@@ -59,14 +64,15 @@ class PeftAdapterMixin:

 def load_adapter(
     self,
-    peft_model_id: str,
+    peft_model_id: Union[str, Dict[str, "torch.Tensor"]],
Should we choose a different name, now that a state dict can be passed? Alternatively, we could add another (optional) argument to pass the state dict, make `peft_model_id` optional, and add a check that exactly one of the two is passed.
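A minimal sketch of that alternative, assuming a hypothetical `adapter_state_dict` argument name (not necessarily what the PR settled on):

```python
from typing import Dict, Optional

import torch


class PeftAdapterMixin:
    def load_adapter(
        self,
        peft_model_id: Optional[str] = None,
        adapter_state_dict: Optional[Dict[str, torch.Tensor]] = None,
    ) -> None:
        # Exactly one of the two sources must be provided: both None and
        # both set are equally invalid.
        if (peft_model_id is None) == (adapter_state_dict is None):
            raise ValueError(
                "Pass exactly one of `peft_model_id` or `adapter_state_dict`."
            )
        # ... resolve the adapter config and weights from whichever source was given ...
```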
Yes, makes sense to me! `peft_model_id_or_state_dict` sounds great - was just wondering about a breaking change here. But I think this function is not yet in a release, so it should probs be fine to change, no?
@@ -75,7 +81,7 @@ def load_adapter(
 Requires peft as a backend to load the adapter weights.

 Args:
-    peft_model_id (`str`):
+    peft_model_id (`str` or dictionary of `torch.Tensor`):
The description below should also be adjusted.
 adapter_name: Optional[str] = None,
 revision: Optional[str] = None,
 token: Optional[str] = None,
 device_map: Optional[str] = "auto",
 max_memory: Optional[str] = None,
 offload_folder: Optional[str] = None,
 offload_index: Optional[int] = None,
 peft_config: Dict[str, Any] = None,
Suggested change:
-    peft_config: Dict[str, Any] = None,
+    peft_config: Optional[Dict[str, Any]] = None,
Looks great! Agreed also with @BenjaminBossan's comments. I think we should maybe add an extra check and raise a proper error. For the naming, what about `peft_model_id_or_state_dict` (maybe that's too long)?
 raise ValueError(
     f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
     "adapter model."
 )
 if peft_config is None:
Suggested change:
-    if peft_config is None:
+    if peft_config is None and isinstance(peft_model_id, str):
and add a check below that raises an error if `peft_model_id` is not a state dict.
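A rough sketch of the resulting branching, with a hypothetical `_check_peft_inputs` helper name (this only illustrates the suggested checks, not the merged code):

```python
from typing import Any, Dict, Optional, Union

import torch


def _check_peft_inputs(
    peft_model_id: Union[str, Dict[str, torch.Tensor]],
    peft_config: Optional[Dict[str, Any]],
) -> None:
    if peft_config is None and not isinstance(peft_model_id, str):
        # A raw state dict carries no adapter config, so the caller must
        # provide one explicitly.
        raise ValueError(
            "`peft_model_id` is a state dict, so `peft_config` must also be passed."
        )
    # When `peft_model_id` is a string, the adapter config can be fetched
    # from the Hub or a local path as before, so `peft_config=None` is fine.
```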
Thank you @patrickvonplaten for adding support for passing a state dict and config, LGTM!
Thanks again!
Have one question regarding `peft_model_id` being a dictionary, otherwise looks good.
 peft_model_id (`str` or dictionary of `torch.Tensor`):
     The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
     and adapter weights.
Doc needs to be updated if this can be a `state_dict`. The name of the argument is not intuitive, but I guess we need to keep it for backward compatibility? Otherwise I would rather have a new arg.
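A possible adjusted description, as a sketch (not the wording that was merged):

```text
peft_model_id (`str` or dictionary of `torch.Tensor`):
    The identifier of the model to look for on the Hub, or a local path to the saved
    adapter config file and adapter weights. Alternatively, a state dict of the adapter
    weights can be passed directly, in which case `peft_config` must also be provided.
```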
I think for BC it's fine as the feature is quite new
Not really if it was part of the release 😅
LGTM, thanks for iterating and adding a test!
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Allow PEFT model dict to be loaded
* make style
* make style
* Apply suggestions from code review
* address comments
* fixup
* final change
* added tests
* fix test
* better logic for handling if adapter has been loaded
* Update tests/peft_integration/test_peft_integration.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
In order to allow `peft` to be leveraged in `diffusers` without breaking changes, we need to allow loading adapters directly from a loaded `state_dict`. The reason is that in `diffusers` we currently store LoRA checkpoints in a format that is different from the PEFT format, so we cannot just pass the model id. This PR allows the user to manually pass a loaded PEFT model checkpoint as well as a PEFT configuration, thus circumventing the need to pass a model id.

In pseudo code, the integration of `transformers` + PEFT in `diffusers` should then look as follows in the `load_lora` function of `diffusers`.

Note, there might be more changes we have to do to PEFT and `transformers`' PEFT integration to be sure that everything works as expected. E.g. I'm not yet sure how to pass `network_alphas` etc. to PEFT to make sure we get 1-to-1 the same result.

cc @younesbelkada @sayakpaul @BenjaminBossan @pacman100
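A rough sketch of what that `load_lora` flow could look like, where `convert_diffusers_to_peft` is a hypothetical placeholder and the state dict is passed in place of the model id as discussed in this PR:

```python
from typing import Any, Dict

import torch


def convert_diffusers_to_peft(
    state_dict: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    # Hypothetical helper: remap diffusers-format LoRA keys to the layout
    # PEFT expects; the actual conversion logic is out of scope here.
    return dict(state_dict)


def load_lora(
    pipeline: Any,
    state_dict: Dict[str, torch.Tensor],
    peft_config: Dict[str, Any],
) -> None:
    # Instead of a Hub model id, pass the in-memory adapter state dict
    # (plus its config) straight to the `load_adapter` API added here.
    peft_state_dict = convert_diffusers_to_peft(state_dict)
    pipeline.text_encoder.load_adapter(
        peft_model_id=peft_state_dict,
        peft_config=peft_config,
    )
```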