
[IP-Adapter] Support multiple IP-Adapters #6573

Merged: 40 commits (branch multi-ipadapter), Jan 31, 2024.
Changes shown below are from 11 commits.
- `7e1bf9b` allow list (Jan 14, 2024)
- `3b55a1e` update (Jan 14, 2024)
- `345b4d6` save (Jan 14, 2024)
- `f6bae6e` add (Jan 15, 2024)
- `baa7b83` remove print lines (Jan 15, 2024)
- `9024698` style (Jan 15, 2024)
- `27fc796` support multi-image (Jan 17, 2024)
- `67908bf` Merge branch 'main' into multi-ipadapter (yiyixuxu, Jan 17, 2024)
- `afd91e3` fix a typo (Jan 17, 2024)
- `7d42455` Merge branch 'multi-ipadapter' of github.com:huggingface/diffusers in… (Jan 17, 2024)
- `1fdcd7b` fix (Jan 17, 2024)
- `45fb582` fix a bug! (yiyixuxu, Jan 18, 2024)
- `1a4c6b1` fix (Jan 19, 2024)
- `d924c47` update (Jan 19, 2024)
- `cc2aa1b` fix (Jan 19, 2024)
- `2c86534` Merge branch 'main' into multi-ipadapter (yiyixuxu, Jan 19, 2024)
- `0049e44` Apply suggestions from code review (yiyixuxu, Jan 24, 2024)
- `4a3df90` ImageProjectionLayers -> image_projection_layers (Jan 24, 2024)
- `ff96407` merge (Jan 24, 2024)
- `193d6e8` fix (Jan 24, 2024)
- `f7f2465` fix-copies (Jan 24, 2024)
- `efa704a` update test (Jan 25, 2024)
- `9abdcf9` add test (Jan 25, 2024)
- `5e47ceb` add prepare_ip_adapter_image_embeds method so pipelines can copy from (Jan 25, 2024)
- `c6670de` update all pipelines support ip-adapter (Jan 25, 2024)
- `711387e` deprecate image_embeds as 3d tensor (Jan 25, 2024)
- `bce309f` corrent num_images_per_prompt behavior (Jan 25, 2024)
- `fae861e` fix batching behavior (Jan 26, 2024)
- `21da205` revert for lcm and sd safe (Jan 26, 2024)
- `accee6b` correct tests (Jan 26, 2024)
- `816578f` update attention processer so backward compatible (Jan 29, 2024)
- `e57103f` revert changes made to ipadapter attention processor to follow the de… (Jan 30, 2024)
- `6cfa34b` Merge branch 'main' into multi-ipadapter (yiyixuxu, Jan 30, 2024)
- `1e68c64` add doc (Jan 30, 2024)
- `2cc1561` update doc (Jan 30, 2024)
- `475046e` update docstring (Jan 30, 2024)
- `98fa0c2` add a slow test for multi (Jan 30, 2024)
- `3a52ecb` remove ddim config (Jan 30, 2024)
- `e742cf4` style (Jan 31, 2024)
- `dcdde9c` Merge branch 'main' into multi-ipadapter (yiyixuxu, Jan 31, 2024)
137 changes: 83 additions & 54 deletions src/diffusers/loaders/ip_adapter.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Dict, Union
from typing import Dict, List, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args
@@ -45,14 +46,15 @@ class IPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
subfolder: str,
weight_name: str,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:

- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
Expand Down Expand Up @@ -87,6 +89,26 @@ def load_ip_adapter(
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""

# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]

if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
Comment on lines +97 to +98 (Member): Does it cater to the case where you have a single model_id and multiple weight names of different IP Adapters?


if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)

if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")

if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
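
The broadcasting rules in this hunk (wrap scalars in lists, repeat a length-1 list, then require equal lengths) can be sketched as a standalone helper. This is a hedged illustration; `broadcast_to_list` and the repo/file names are invented for the example, not part of diffusers:

```python
def broadcast_to_list(value, target_len):
    """Mirror load_ip_adapter's normalization: wrap a scalar in a list,
    repeat a length-1 list to target_len, then insist on equal lengths."""
    if not isinstance(value, list):
        value = [value]
    if len(value) == 1:
        value = value * target_len
    if len(value) != target_len:
        raise ValueError("argument lists must have the same length.")
    return value

# One repo id paired with two weight files (names illustrative):
weight_names = ["adapter_a.safetensors", "adapter_b.safetensors"]
repos = broadcast_to_list("some/repo-id", len(weight_names))
subfolders = broadcast_to_list("models", len(weight_names))
```

Under these rules a single repo id or subfolder is simply repeated to match however many weight names are given.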

# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
@@ -100,61 +122,68 @@
"file_type": "attn_procs_weights",
"framework": "pytorch",
}

if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict

keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")

# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
subfolder=Path(subfolder, "image_encoder").as_posix(),
).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(model_file, map_location="cpu")
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
self.feature_extractor = CLIPImageProcessor()
self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"])

# load ip-adapter into unet
state_dict = pretrained_model_name_or_path_or_dict
Review comment (sayakpaul, Member, Jan 31, 2024): This is the case when someone passes the state_dict directly. Shouldn't this statement then be guarded with a check? Something like `if isinstance(pretrained_model_name_or_path_or_dict, dict)`?

Reply (yiyixuxu, Collaborator, Author): I think it is already done; it's part of this if/else:

if not isinstance(pretrained_model_name_or_path_or_dict, dict):
        ....
else:
        state_dict = pretrained_model_name_or_path_or_dict


state_dicts.append(state_dict)

keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
Comment on lines +155 to +157 (Member): We are relying on the state_dict to check if we should raise the error or not. I don't think we have to wait till this position to raise this error. As soon as we have access to the final state_dict, we could raise it. In this case, that could be before state_dicts.append(). Generally, it's better to raise errors as early as possible.

Reply (Collaborator, Author): moved up :)
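
The safetensors branch above groups a flat checkpoint by key prefix into exactly the two required top-level entries. A framework-free sketch of that grouping, with placeholder strings standing in for tensors (`split_ip_adapter_checkpoint` is an illustrative name):

```python
def split_ip_adapter_checkpoint(flat):
    """Group flat 'image_proj.*' / 'ip_adapter.*' keys into nested dicts,
    mirroring the safetensors loading loop in load_ip_adapter."""
    state_dict = {"image_proj": {}, "ip_adapter": {}}
    for key, value in flat.items():
        if key.startswith("image_proj."):
            state_dict["image_proj"][key.replace("image_proj.", "")] = value
        elif key.startswith("ip_adapter."):
            state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = value
    return state_dict

flat = {"image_proj.proj.weight": "w0", "ip_adapter.1.to_k_ip.weight": "w1"}
sd = split_ip_adapter_checkpoint(flat)
```

The resulting top-level keys are exactly `["image_proj", "ip_adapter"]`, which is what the validation check requires.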


# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=Path(subfolder, "image_encoder").as_posix(),
).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
self.feature_extractor = CLIPImageProcessor()
self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"])

# load ip-adapter into unet
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet._load_ip_adapter_weights(state_dict)
unet._load_ip_adapter_weights(state_dicts)

def set_ip_adapter_scale(self, scale):
if not isinstance(scale, list):
scale = [scale]
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for attn_processor in unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
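
`set_ip_adapter_scale` applies the same scalar-to-list normalization, so a plain float keeps working for a single adapter. A hedged sketch; the helper name and the explicit length check are assumptions added here for illustration, not code from the diff:

```python
def normalize_scales(scale, num_adapters):
    """Accept one float applied to a single adapter, or a per-adapter list.
    The length check is an added assumption, not present in the shown diff."""
    if not isinstance(scale, list):
        scale = [scale]
    if len(scale) != num_adapters:
        raise ValueError(f"expected {num_adapters} scale value(s), got {len(scale)}.")
    return scale

single = normalize_scales(0.6, 1)
per_adapter = normalize_scales([0.5, 0.8], 2)
```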
57 changes: 37 additions & 20 deletions src/diffusers/loaders/unet.py
@@ -24,7 +24,12 @@
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn

from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
from ..models.embeddings import (
ImageProjection,
IPAdapterFullImageProjection,
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
USE_PEFT_BACKEND,
@@ -761,28 +766,14 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
image_projection.load_state_dict(updated_state_dict)
return image_projection

def _load_ip_adapter_weights(self, state_dict):
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
Review comment (Member): 100 percent the right choice!

from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
Comment on lines 772 to 777 (Member): I think we can move the imports to the top, no?

Reply (Collaborator, Author): I was following the same pattern in this file: https://github.com/huggingface/diffusers/blob/87a92f779c5ba9c180aec4b90c38149eb108d888/src/diffusers/loaders/unet.py#L449. I thought it was to avoid a circular import or something, but I'm not really sure. I can look into it, maybe in a separate PR, to see if we can move all the imports to the top.


if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds = 4
elif "proj.3.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]

# Set encoder_hid_proj after loading ip_adapter weights,
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None

# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 1
@@ -796,6 +787,7 @@ def _load_ip_adapter_weights(self, state_dict):
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.config.block_out_channels[block_id]

if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
@@ -805,6 +797,18 @@
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds += [4]
Review comment (Member): Let's give "4" a variable.

Follow-up (Member): I think this is yet to be addressed?

elif "proj.3.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
else:
# IP-Adapter Plus
num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
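
The branch above infers how many image tokens each adapter variant contributes from which projection weights its checkpoint contains. A sketch with fabricated stand-in state dicts (the shapes and the `_FakeLatents` helper are illustrative only):

```python
def count_image_text_embeds(state_dicts):
    """Infer the per-adapter image token count, mirroring the detection
    branch in _convert_ip_adapter_attn_to_diffusers."""
    num_image_text_embeds = []
    for state_dict in state_dicts:
        proj = state_dict["image_proj"]
        if "proj.weight" in proj:
            num_image_text_embeds += [4]    # vanilla IP-Adapter
        elif "proj.3.weight" in proj:
            num_image_text_embeds += [257]  # Full Face: 256 CLIP tokens + 1 CLS token
        else:
            # IP-Adapter Plus: token count comes from the latents shape
            num_image_text_embeds += [proj["latents"].shape[1]]
    return num_image_text_embeds

class _FakeLatents:
    shape = (1, 16, 768)  # (batch, tokens, dim); values are made up

plain = {"image_proj": {"proj.weight": None}}
plus = {"image_proj": {"latents": _FakeLatents()}}
```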

attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
@@ -813,16 +817,29 @@
).to(dtype=self.dtype, device=self.device)

value_dict = {}
for k, w in attn_procs[name].state_dict().items():
value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
for i, state_dict in enumerate(state_dicts):
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})

attn_procs[name].load_state_dict(value_dict)
key_id += 2

return attn_procs
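
Stacking several adapters into one attention processor renames each checkpoint's key/value projection weights with an adapter index, as in the `value_dict` loop above. A pure-Python sketch with placeholder strings instead of tensors:

```python
def build_value_dict(state_dicts, key_id):
    """Map adapter i's '{key_id}.to_k_ip.weight' / '{key_id}.to_v_ip.weight'
    onto the processor's 'to_k_ip.{i}.weight' / 'to_v_ip.{i}.weight'."""
    value_dict = {}
    for i, state_dict in enumerate(state_dicts):
        value_dict[f"to_k_ip.{i}.weight"] = state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]
        value_dict[f"to_v_ip.{i}.weight"] = state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]
    return value_dict

sd_a = {"ip_adapter": {"1.to_k_ip.weight": "ka", "1.to_v_ip.weight": "va"}}
sd_b = {"ip_adapter": {"1.to_k_ip.weight": "kb", "1.to_v_ip.weight": "vb"}}
vd = build_value_dict([sd_a, sd_b], key_id=1)
```

Each adapter keeps its own key/value projections inside the shared processor, indexed by position in the list.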

def _load_ip_adapter_weights(self, state_dicts):
# Set encoder_hid_proj after loading ip_adapter weights,
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None
Comment on lines +834 to +836 (Member): Sorry, I don't understand the comment fully. Could you elaborate?

Reply (Collaborator, Author): I can't 😛 I assume it's notes from the contributor of IP-Adapter Plus.


attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
self.set_attn_processor(attn_procs)

# convert IP-Adapter Image Projection layers to diffusers
image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
image_projection_layers = []
for state_dict in state_dicts:
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
image_projection_layer.to(device=self.device, dtype=self.dtype)
image_projection_layers.append(image_projection_layer)
Comment on lines +842 to +846 (Member): Very nice delegation. First convert to attention procs and then handle the rest of the stuff, like projection, with a dedicated method.


self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
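
`MultiIPAdapterImageProjection` holds one projection layer per adapter and applies each to its own image embedding. A framework-free sketch of that fan-out (the class below is a simplification for illustration, not the real torch module):

```python
class MultiProjectionSketch:
    """Apply projection layer i to image embedding i and collect the results,
    the way MultiIPAdapterImageProjection fans out over its layers."""
    def __init__(self, projection_layers):
        self.projection_layers = list(projection_layers)

    def __call__(self, image_embeds):
        if len(image_embeds) != len(self.projection_layers):
            raise ValueError("need one image embedding per projection layer.")
        return [layer(embed) for layer, embed in zip(self.projection_layers, image_embeds)]

# Two toy "projections" standing in for per-adapter image projection layers:
proj = MultiProjectionSketch([lambda e: e * 2, lambda e: e + 1])
out = proj([3, 10])
```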