-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add IP-adapter support for stable diffusion (#766)
# What does this PR do? This adds support for Stable Diffusion IP-adapters. IP-adapter (Image Prompt adapter) is a Stable Diffusion add-on for using images as prompts, similar to Midjourney and DaLLE 3. You can use it to copy the style, composition, or a face in the reference image. Fixes #718 - [x] Export of sd models loaded with loaded IP adapter weights + image encoder - [x] Ensure the caching works - [x] Inference: add image encoder to the pipelines * Export via CLI ```bash optimum-cli export neuron --model stable-diffusion-v1-5/stable-diffusion-v1-5 --ip_adapter_id h94/IP-Adapter --ip_adapter_subfolder models --ip_adapter_weight_name ip-adapter-full-face_sd15.bin --ip_adapter_scale 0.5 --batch_size 1 --height 512 --width 512 --num_images_per_prompt 1 --auto_cast matmul --auto_cast_type bf16 ip_adapter_neuron/ ``` * Export via NeuronModel API ```python from optimum.neuron import NeuronStableDiffusionPipeline model_id = "runwayml/stable-diffusion-v1-5" compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} input_shapes = {"batch_size": 1, "height": 512, "width": 512} stable_diffusion = NeuronStableDiffusionPipeline.from_pretrained( model_id, export=True, ip_adapter_id="h94/IP-Adapter", ip_adapter_subfolder="models", ip_adapter_weight_name="ip-adapter-full-face_sd15.bin", ip_adapter_scale=0.5, **compiler_args, **input_shapes, ) # Save locally or upload to the HuggingFace Hub save_directory = "ip_adapter_neuron/" stable_diffusion.save_pretrained(save_directory) ``` * Inference * With `ip_adapter_image` as input ```python from optimum.neuron import NeuronStableDiffusionPipeline model_id = "runwayml/stable-diffusion-v1-5" compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} input_shapes = {"batch_size": 1, "height": 512, "width": 512} stable_diffusion = NeuronStableDiffusionPipeline.from_pretrained( model_id, export=True, ip_adapter_id="h94/IP-Adapter", ip_adapter_subfolder="models", ip_adapter_weight_name="ip-adapter-full-face_sd15.bin", ip_adapter_scale=0.5, **compiler_args, **input_shapes, ) # Save locally or upload to the HuggingFace Hub save_directory = "ip_adapter_neuron/" stable_diffusion.save_pretrained(save_directory) ``` * With `ip_adapter_image_embeds` as input (encode the image first) ```python image_embeds = stable_diffusion.prepare_ip_adapter_image_embeds( ip_adapter_image=image, ip_adapter_image_embeds=None, device=None, num_images_per_prompt=1, do_classifier_free_guidance=True, ) torch.save(image_embeds, "image_embeds.ipadpt") image_embeds = torch.load("image_embeds.ipadpt") images = stable_diffusion( prompt="a polar bear sitting in a chair drinking a milkshake", ip_adapter_image_embeds=image_embeds, negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality", num_inference_steps=100, generator=generator, ).images[0] image.save("polar_bear.png") ``` ### Next steps * Support multiple IP adapters * Ensure it works for sdxl * Extend the support for diffusion transformers * Documentation along with refactoring ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you make sure to update the documentation with your changes? - [ ] Did you write any new necessary tests?
- Loading branch information
1 parent
2f06ce6
commit 6786b8c
Showing
17 changed files
with
642 additions
and
261 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.