Skip to content

Commit

Permalink
[IP-Adapter] Support multiple IP-Adapters (huggingface#6573)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Alvaro Somoza <somoza.alvaro@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  • Loading branch information
4 people authored and dg845 committed Feb 2, 2024
1 parent 64533cd commit 8fac00f
Show file tree
Hide file tree
Showing 25 changed files with 895 additions and 235 deletions.
73 changes: 61 additions & 12 deletions docs/source/en/using-diffusers/loading_adapters.md
Original file line number Diff line number Diff line change
Expand Up @@ -506,22 +506,11 @@ import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image

noise_scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1
)

pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
scheduler=noise_scheduler,
).to("cuda")

pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")

pipeline.set_ip_adapter_scale(0.7)
Expand Down Expand Up @@ -550,6 +539,66 @@ image = pipeline(
</div>
</div>


You can load multiple IP-Adapter models and use multiple reference images at the same time. In this example we use IP-Adapter-Plus face model to create a consistent character and also use IP-Adapter-Plus model along with 10 images to create a coherent style in the image we generate.

```python
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"h94/IP-Adapter",
subfolder="models/image_encoder",
torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
image_encoder=image_encoder,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter(
"h94/IP-Adapter",
subfolder="sdxl_models",
weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"]
)
pipeline.set_ip_adapter_scale([0.7, 0.3])
pipeline.enable_model_cpu_offload()

face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")
style_folder = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy"
style_images = [load_image(f"{style_folder}/img{i}.png") for i in range(10)]

generator = torch.Generator(device="cpu").manual_seed(0)

image = pipeline(
prompt="wonderwoman",
ip_adapter_image=[style_images, face_image],
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=50, num_images_per_prompt=1,
generator=generator,
).images[0]
```
<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_style_grid.png" />
<figcaption class="mt-2 text-center text-sm text-gray-500">style input image</figcaption>
</div>

<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">face input image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_multi_out.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
</div>
</div>


### LCM-Lora

You can use IP-Adapter with LCM-Lora to achieve "instant fine-tune" with custom images. Note that you need to load IP-Adapter weights before loading the LCM-Lora weights.
Expand Down
134 changes: 81 additions & 53 deletions src/diffusers/loaders/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Dict, Union
from typing import Dict, List, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args
Expand Down Expand Up @@ -45,9 +46,9 @@ class IPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
subfolder: str,
weight_name: str,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
**kwargs,
):
"""
Expand Down Expand Up @@ -87,6 +88,26 @@ def load_ip_adapter(
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""

# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]

if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)

if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)

if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")

if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")

# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
Expand All @@ -100,61 +121,68 @@ def load_ip_adapter(
"file_type": "attn_procs_weights",
"framework": "pytorch",
}

if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict

keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")

# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
subfolder=Path(subfolder, "image_encoder").as_posix(),
).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(model_file, map_location="cpu")
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
self.feature_extractor = CLIPImageProcessor()
self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"])

# load ip-adapter into unet
state_dict = pretrained_model_name_or_path_or_dict

keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")

state_dicts.append(state_dict)

# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=Path(subfolder, "image_encoder").as_posix(),
).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
feature_extractor = CLIPImageProcessor()
self.register_modules(feature_extractor=feature_extractor)

# load ip-adapter into unet
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet._load_ip_adapter_weights(state_dict)
unet._load_ip_adapter_weights(state_dicts)

def set_ip_adapter_scale(self, scale):
if not isinstance(scale, list):
scale = [scale]
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for attn_processor in unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
Expand Down
59 changes: 39 additions & 20 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn

from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
from ..models.embeddings import (
ImageProjection,
IPAdapterFullImageProjection,
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
USE_PEFT_BACKEND,
Expand Down Expand Up @@ -763,28 +768,14 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
image_projection.load_state_dict(updated_state_dict)
return image_projection

def _load_ip_adapter_weights(self, state_dict):
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)

if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds = 4
elif "proj.3.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]

# Set encoder_hid_proj after loading ip_adapter weights,
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None

# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 1
Expand All @@ -798,6 +789,7 @@ def _load_ip_adapter_weights(self, state_dict):
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.config.block_out_channels[block_id]

if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
Expand All @@ -807,6 +799,18 @@ def _load_ip_adapter_weights(self, state_dict):
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds += [4]
elif "proj.3.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
else:
# IP-Adapter Plus
num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]

attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
Expand All @@ -815,16 +819,31 @@ def _load_ip_adapter_weights(self, state_dict):
).to(dtype=self.dtype, device=self.device)

value_dict = {}
for k, w in attn_procs[name].state_dict().items():
value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
for i, state_dict in enumerate(state_dicts):
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})

attn_procs[name].load_state_dict(value_dict)
key_id += 2

return attn_procs

def _load_ip_adapter_weights(self, state_dicts):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]
# Set encoder_hid_proj after loading ip_adapter weights,
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None

attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
self.set_attn_processor(attn_procs)

# convert IP-Adapter Image Projection layers to diffusers
image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
image_projection_layers = []
for state_dict in state_dicts:
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
image_projection_layer.to(device=self.device, dtype=self.dtype)
image_projection_layers.append(image_projection_layer)

self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
Loading

0 comments on commit 8fac00f

Please sign in to comment.