
[IP-Adapter] Support multiple IP-Adapters #6573

Merged
merged 40 commits into main on Jan 31, 2024

Conversation

@yiyixuxu (Collaborator) commented Jan 15, 2024

initial draft for multiple IP-Adapter support

for #6318
see the discussion thread here #6544

to-do

  • add multi-adapter support
  • add multi-image support
  • refactor
  • doc and tests

working now! thanks to @asomoza

testing multi-adapter and multi-image

testing script

# yiyi testing script for multi-ipadapter: face + style folder
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image

noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1
)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", 
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
)

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"])
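# one scale per loaded adapter, in load order: 0.7 for style ("plus"), 0.3 for face ("plus-face")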
pipeline.set_ip_adapter_scale([0.7, 0.3])

pipeline.enable_model_cpu_offload()

face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")

style_folder = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy"
style_images = [load_image(f"{style_folder}/img{i}.png") for i in range(10)]

generator = torch.Generator(device="cpu").manual_seed(0)

image = pipeline(
    prompt="wonderwoman",
    ip_adapter_image=[style_images, face_image],
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=50, num_images_per_prompt=1,
    generator=generator,
).images[0]
image.save("yiyi_test_out.png")

image inputs

face image: women_input
style images: style_grid (grid of the 10 style_ziggy images)

output

yiyi_test_7_out

slow test

testing the previous API with a single IP-Adapter and a single image input. the script below generates identical results on the main and PR branches

from diffusers import AutoPipelineForText2Image
import torch
from diffusers.utils import load_image


pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
# ip-adapter

ip_adapter_weights = {
    "ip-adapter": "ip-adapter_sd15.bin",
    "ip-adapter-plus": "ip-adapter-plus_sd15.bin",
    "ip-adapter-full-face": "ip-adapter-full-face_sd15.bin",
}

for adapter_name, weight_name in ip_adapter_weights.items():
    pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=weight_name)
    pipeline.set_ip_adapter_scale(0.6)
    image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
    generator = torch.Generator(device="cpu").manual_seed(33)
    images = pipeline(
        prompt='best quality, high quality, wearing sunglasses',
        ip_adapter_image=image,
        negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
        num_inference_steps=50,
        num_images_per_prompt=1,
        generator=generator,
    ).images
    images[0].save(f"yiyi_test_2_out_{adapter_name}.png")
outputs (left to right): ip-adapter, ip-adapter-full-face, ip-adapter-plus
yiyi_test_2_out_ip-adapter (PR) | yiyi_test_2_out_ip-adapter-full-face (PR) | yiyi_test_2_out_ip-adapter-plus

testing batch generation

The implementation currently on main does not work with batches - it fails with multiple prompts or with num_images_per_prompt > 1. This is fixed in this PR.

testing script

# yiyi testing script for multi-ipadapter: face + style folder
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image, make_image_grid

noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1
)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", 
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
)

pipeline.load_ip_adapter(["h94/IP-Adapter"], subfolder=["sdxl_models"], weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"])
pipeline.set_ip_adapter_scale([0.7, 0.3])

pipeline.enable_model_cpu_offload()

face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")

style_folder = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy"
style_images = [load_image(f"{style_folder}/img{i}.png") for i in range(10)]

generator = torch.Generator(device="cpu").manual_seed(33)

images = pipeline(
    prompt=["batman", "wonderwoman"],
    ip_adapter_image=[style_images, face_image],
    negative_prompt=["monochrome, lowres, bad anatomy, worst quality, low quality"] * 2, 
    num_inference_steps=50, num_images_per_prompt=2,
    generator=generator,
).images

make_image_grid(images, rows=2, cols=2).save("yiyi_test_out.png")

output: yiyi_test_7_out (2×2 image grid)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@asomoza (Member) commented Jan 17, 2024

if you use multiple images for the style, does it work? with just 4 images and the "wonder woman" prompt, you should get something like this:

Screenshot 2024-01-17 080446

for style I usually use the normal ("not plus") adapter. Also, the face adapter is really strong, so it could be changing the style. Just looking at the face, it looks like even 0.3 is still too strong; the face should be more "animated", like this one:

ComfyUI_00145_

your result looks more realistic, more like what I get if I set the scale to 1.0 and use a mask:

ComfyUI_00151_

to me it looks like the second adapter is overwriting the first one.

@yiyixuxu (Collaborator, Author)

thanks @asomoza
super valuable insights! yeah my results are off, either I missed something or had a bug somewhere. Going to dig into it more.

I also saw in InvokeAI's code there are fields begin_step_percent and end_step_percent - I did not implement that feature. Do you think it could be that?

@asomoza (Member) commented Jan 17, 2024

IMO that shouldn't matter, because they do it in the pipeline: they check the percentage, convert it to steps, and simply set the scale to 0.0 before and after the window. It's a nice feature to have, though, and it can easily be added later.
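
For illustration, a minimal sketch of that approach on top of the testing script above (assuming a diffusers version with the callback_on_step_end hook; the 0.25-0.75 window and the 0.6 scale are arbitrary example values, not InvokeAI's defaults):

# hypothetical begin/end window, expressed as fractions of the schedule
begin_step_percent, end_step_percent = 0.25, 0.75

def toggle_ip_adapter_scale(pipe, step_index, timestep, callback_kwargs):
    # zero out the IP-Adapter influence outside the chosen window
    progress = step_index / pipe.num_timesteps
    in_window = begin_step_percent <= progress <= end_step_percent
    pipe.set_ip_adapter_scale(0.6 if in_window else 0.0)
    return callback_kwargs

image = pipeline(
    prompt="wonderwoman",
    ip_adapter_image=face_image,
    callback_on_step_end=toggle_ip_adapter_scale,
).images[0]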

I'm still struggling to understand the diffusers implementation - maybe I'm missing some diffusers coding practices, but shouldn't the image_embeds always be passed in cross_attention_kwargs, as you're doing in this PR? It would be even better to take the ImageProjection out of the unet forward; at least, I don't see the need to run it on every step when it can be done once with the image encoding. I'm not 100% sure, since I don't use the default diffusers code.
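
As a rough sketch of the "do it once with the image encoding" idea (illustrative only - this is not the code path the PR takes; it reuses image_encoder and face_image from the testing script above and assumes CLIPImageProcessor defaults):

import torch
from transformers import CLIPImageProcessor

feature_extractor = CLIPImageProcessor()
pixel_values = feature_extractor(images=face_image, return_tensors="pt").pixel_values

# run the image encoder (and, in principle, the projection layers) once up
# front, then reuse the embeddings at every denoising step instead of
# recomputing them inside the unet forward
with torch.no_grad():
    image_embeds = image_encoder(
        pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
    ).image_embeds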

@yiyixuxu (Collaborator, Author)

@asomoza
we are going to refactor the design once we get correct results :)

It would be even better to take the ImageProjection out of the unet forward; at least, I don't see the need to run it on every step when it can be done once with the image encoding. I'm not 100% sure, since I don't use the default diffusers code.

We did it that way to be consistent with the unet design in the rest of our codebase. I think it makes less sense now with multi-image support; we are considering refactoring our unet to separate these projection layers as well.

@asomoza (Member) commented Jan 17, 2024

oh I see, thanks for the response. I'm also currently implementing multiple IP-Adapters myself; if I find something useful I'll let you know.

@russmaschmeyer

Super excited you're looking at adding these features. I've found chained IP Adapters very useful in ComfyUI for generating new backgrounds for an existing photo subject. ComfyUI also has a Conditioning (Set Mask) node that allows you to set a drawing mask for the text prompt as well as a Conditioning (Combine) node to bring foreground and background prompt/controlnet conditioning together into a single positive prompt which then gets passed into the sampler.

Been researching furiously but can't find the equivalent of that conditioning masking and combining capability in Diffusers. Does an equivalent exist?

@asomoza (Member) commented Jan 18, 2024

@russmaschmeyer I did bring up masking for IP-Adapters, but as I understood from the comments, most people don't use it - they use diffusers as a fast and simple way of generating, so masking is delegated to community pipelines. This is somewhat related to this PR but mostly not, so I think it would be better to ask this question in the discussions tab.

IMO what you're looking for is a combination of regional prompting with masks for ControlNet and IP-Adapters, which is the best combination for image composition. I need to build this sooner or later anyway, and maybe I'll have time to port it to a diffusers pipeline if no one has done it by then.

@asomoza (Member) commented Jan 18, 2024

@yiyixuxu I'm done with my implementation and got it working; this is the result:

face + style (11 images): 20240118042242 | face + style PLUS (11 images): 20240118042625

Reading your code, I think your problem is in the init of the IP-Adapter attention processor. I use this:

self.to_k_ip = nn.ModuleList([nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))])
self.to_v_ip = nn.ModuleList([nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))])
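
For context, a minimal sketch of where those lines would sit - a hypothetical processor class, assuming num_tokens carries one entry per loaded adapter, as the PR does:

import torch.nn as nn

class IPAdapterAttnProcessorSketch(nn.Module):
    # one K/V projection pair per loaded IP-Adapter, indexed in load order
    def __init__(self, hidden_size, cross_attention_dim, num_tokens=(4,), scale=1.0):
        super().__init__()
        num_adapters = len(num_tokens)
        self.num_tokens = num_tokens
        # accept either one shared scale or one scale per adapter
        self.scale = scale if isinstance(scale, list) else [scale] * num_adapters
        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(num_adapters)]
        )
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(num_adapters)]
        )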

@yiyixuxu (Collaborator, Author)

@asomoza
thank you! you're right!! I had a bug there. Thanks so much for looking into this :) ❤️

@asomoza (Member) commented Jan 18, 2024

Glad to be of help. I've been running tests with multiple combinations and the results are the same as ComfyUI's. Now I'll wait for the final refactor to see how much I need to deviate from the diffusers code (hopefully not much). I still need to add the start/end percentages and the masking options afterwards.

@russmaschmeyer

IMO what you're looking for is a combination of regional prompting with masks for ControlNet and IP-Adapters, which is the best combination for image composition. I need to build this sooner or later anyway, and maybe I'll have time to port it to a diffusers pipeline if no one has done it by then.

YES! That nails it. I'll research regional prompting solutions. FWIW, I have found IP-Adapter attention masking to be VERY powerful, so I'm happy to add another voice to the "it would be amazing if that were supported in Diffusers!" vote. In ComfyUI I've found that regional prompting alone (with masks) doesn't get you there - you need both IP attention masking AND conditioning masking.

Thanks for the tip @asomoza!

@@ -763,28 +768,14 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
        image_projection.load_state_dict(updated_state_dict)
        return image_projection

-    def _load_ip_adapter_weights(self, state_dict):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
Member:

100 percent the right choice!

Comment on lines 772 to 777
from ..models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
)
Member:

I think we can move the imports to the top, no?

Collaborator Author:

I was following the pattern in this file: https://github.com/huggingface/diffusers/blob/87a92f779c5ba9c180aec4b90c38149eb108d888/src/diffusers/loaders/unet.py#L449

I thought it was to avoid a circular import or something, but I'm not really sure. I can look into it, maybe in a separate PR, to see if we can move all the imports to the top.

Comment on lines +832 to +833
if not isinstance(state_dicts, list):
    state_dicts = [state_dicts]
Member:

Is it to ensure BW compatibility? I don't see how state_dicts could not be a list.

Collaborator Author:

yes, it is - a plain (non-list) state dict still works for the old single-adapter path

Comment on lines +834 to +836
# Set encoder_hid_proj after loading ip_adapter weights,
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None
Member:

Sorry I don't understand the comment fully. Could you elaborate?

Collaborator Author:

I can't 😛 I assume it's a note from the contributor of IP-Adapter Plus

Comment on lines +842 to +846
image_projection_layers = []
for state_dict in state_dicts:
    image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
    image_projection_layer.to(device=self.device, dtype=self.dtype)
    image_projection_layers.append(image_projection_layer)
Member:

Very nice delegation. First convert to attention processors, then handle the rest, like the projection, with a dedicated method.

residual = hidden_states

# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
    if isinstance(encoder_hidden_states, tuple):
Member:

BW compatibility?

@sayakpaul (Member) left a comment

Looking solid! I only left a handful of comments.

(not merge-blocking): Do we need to update all those pipelines in this PR? Maybe we could open it up for the community?

@yiyixuxu yiyixuxu merged commit 2e8d18e into main Jan 31, 2024
16 checks passed
@yiyixuxu yiyixuxu deleted the multi-ipadapter branch February 1, 2024 16:57
@okotaku mentioned this pull request Feb 2, 2024
dg845 pushed a commit to dg845/diffusers that referenced this pull request Feb 2, 2024
@alexblattner

any plans to add face id? I'd love to use face plus with face id but can't with the community pipeline

@alexblattner

also, is there a from_single_file equivalent for the load_ip_adapter function?

@sayakpaul (Member)

also, is there a from_single_file equivalent for the load_ip_adapter function?

I don't think there's any need as we directly load the original single-file checkpoint of IP Adapter from the get-go.

@alexblattner

@sayakpaul what if I have my own trained ip adapter? What if I want to use models from 2 different users?

@sayakpaul
Copy link
Member

If they follow the original IP-Adapter format (example), then the load_ip_adapter() method already works.
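
For instance (hypothetical repo and file names - any checkpoint that keeps the original "image_proj" + "ip_adapter" layout loads the same way):

pipeline.load_ip_adapter(
    "your-account/my-ip-adapter",  # hypothetical custom repo
    subfolder="models",
    weight_name="my_ip_adapter.safetensors",
)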

@alexblattner

yes, but only for one. You can't use multiple IP-Adapters from different sources with the current implementation

@sayakpaul (Member)

Help us with a reproducible snippet in a new thread.

@alexblattner

I was only thinking of using ip_adapter_face_plus with ip_adapter_face_id, which aren't in the same directory. It seems like it would be a pain to implement with the current design.

@okaris commented Apr 24, 2024

@alexblattner you should be able to use it like:

pipeline.load_ip_adapter(
    ["your-hf-account/IP-Adapter-1", "other-hf-account/IP-Adapter-2"],
    subfolder=["sdxl_models", "sdxl_models"],
    weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"],
)
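
With both adapters loaded, set_ip_adapter_scale then takes one value per adapter, e.g. pipeline.set_ip_adapter_scale([0.7, 0.3]).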

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
@alexblattner

@okaris doesn't work with faceid

@okaris commented May 1, 2024

@alexblattner can you send a sample code please?
