Skip to content

Commit

Permalink
Add IP-adapter support for stable diffusion (#766)
Browse files Browse the repository at this point in the history
# What does this PR do?

This adds support for Stable Diffusion IP-adapters.

IP-Adapter (Image Prompt Adapter) is a Stable Diffusion add-on for using
images as prompts, similar to Midjourney and DALL·E 3. You can use it to
copy the style, composition, or a face from the reference image.

Fixes #718

- [x] Export of SD models with loaded IP-Adapter weights + image
encoder
- [x] Ensure the caching works
- [x] Inference: add image encoder to the pipelines

* Export via CLI

```bash
optimum-cli export neuron --model stable-diffusion-v1-5/stable-diffusion-v1-5 \
                          --ip_adapter_id h94/IP-Adapter \
                          --ip_adapter_subfolder models \
                          --ip_adapter_weight_name ip-adapter-full-face_sd15.bin \
                          --ip_adapter_scale 0.5 \
                          --batch_size 1 --height 512 --width 512 --num_images_per_prompt 1 \
                          --auto_cast matmul --auto_cast_type bf16 ip_adapter_neuron/
```

* Export via NeuronModel API

```python
from optimum.neuron import NeuronStableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
input_shapes = {"batch_size": 1, "height": 512, "width": 512}

stable_diffusion = NeuronStableDiffusionPipeline.from_pretrained(
    model_id, 
    export=True, 
    ip_adapter_id="h94/IP-Adapter",
    ip_adapter_subfolder="models",
    ip_adapter_weight_name="ip-adapter-full-face_sd15.bin",
    ip_adapter_scale=0.5,
    **compiler_args, 
    **input_shapes,
)

# Save locally or upload to the HuggingFace Hub
save_directory = "ip_adapter_neuron/"
stable_diffusion.save_pretrained(save_directory)
```

* Inference
    * With `ip_adapter_image` as input
    ```python
    from optimum.neuron import NeuronStableDiffusionPipeline
    
    model_id = "runwayml/stable-diffusion-v1-5"
    compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
    input_shapes = {"batch_size": 1, "height": 512, "width": 512}
    
    stable_diffusion = NeuronStableDiffusionPipeline.from_pretrained(
        model_id, 
        export=True, 
        ip_adapter_id="h94/IP-Adapter",
        ip_adapter_subfolder="models",
        ip_adapter_weight_name="ip-adapter-full-face_sd15.bin",
        ip_adapter_scale=0.5,
        **compiler_args, 
        **input_shapes,
    )
    
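    # Save locally or upload to the HuggingFace Hub
    save_directory = "ip_adapter_neuron/"
    stable_diffusion.save_pretrained(save_directory)

    # Run inference, passing the reference image (`image`, a PIL image loaded beforehand) as the image prompt
    images = stable_diffusion(
        prompt="a polar bear sitting in a chair drinking a milkshake",
        ip_adapter_image=image,
        negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
        num_inference_steps=100,
    ).images[0]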
    ```
    
    * With `ip_adapter_image_embeds` as input (encode the image first)
    ```python
    image_embeds = stable_diffusion.prepare_ip_adapter_image_embeds(
        ip_adapter_image=image,
        ip_adapter_image_embeds=None,
        device=None,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )
    torch.save(image_embeds, "image_embeds.ipadpt")

    image_embeds = torch.load("image_embeds.ipadpt")
    images = stable_diffusion(
        prompt="a polar bear sitting in a chair drinking a milkshake",
        ip_adapter_image_embeds=image_embeds,
        negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
        num_inference_steps=100,
        generator=generator,
    ).images[0]

    images.save("polar_bear.png")
    ```

### Next steps

* Support multiple IP adapters
* Ensure it works for sdxl
* Extend the support for diffusion transformers
* Documentation along with refactoring


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
  • Loading branch information
JingyaHuang authored Feb 17, 2025
1 parent 2f06ce6 commit 6786b8c
Show file tree
Hide file tree
Showing 17 changed files with 642 additions and 261 deletions.
43 changes: 39 additions & 4 deletions optimum/commands/export/neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,19 +179,53 @@ def parse_args_neuronx(parser: "ArgumentParser"):
type=float,
help="List of scaling factors for the lora adapters.",
)
optional_group.add_argument(
"--output_attentions",
action="store_true",
help="Whether or not for the traced model to return the attentions tensors of all attention layers.",
)

# Diffusion Only
optional_group.add_argument(
"--controlnet_ids",
default=None,
nargs="*",
type=str,
help="List of model ids (eg. `thibaud/controlnet-openpose-sdxl-1.0`) of ControlNet models.",
)
optional_group.add_argument(
"--output_attentions",
action="store_true",
help="Whether or not for the traced model to return the attentions tensors of all attention layers.",
ip_adapter_group = parser.add_argument_group("IP adapters")
ip_adapter_group.add_argument(
"--ip_adapter_id",
default=None,
nargs="*",
type=str,
help=(
"Model ids (eg. `h94/IP-Adapter`) of IP-Adapter models hosted on the Hub or paths to local directories containing the IP-Adapter weights."
),
)
ip_adapter_group.add_argument(
"--ip_adapter_subfolder",
default=None,
nargs="*",
type=str,
help="The subfolder location of a model file within a larger model repository on the Hub or locally. If a list is passed, it should have the same length as `ip_adapter_weight_names`.",
)
ip_adapter_group.add_argument(
"--ip_adapter_weight_name",
default=None,
nargs="*",
type=str,
help="The name of the weight file to load. If a list is passed, it should have the same length as `ip_adapter_subfolders`.",
)
ip_adapter_group.add_argument(
"--ip_adapter_scale",
default=None,
nargs="*",
type=float,
help="Scaling factors for the IP-Adapters.",
)

# Static Input Shapes
input_group = parser.add_argument_group("Input shapes")
doc_input = "that the Neuronx-cc compiler exported model will be able to take as input."
input_group.add_argument(
Expand Down Expand Up @@ -262,6 +296,7 @@ def parse_args_neuronx(parser: "ArgumentParser"):
help=f"Audio tasks only. Audio sequence length {doc_input}",
)

# Optimization Level
level_group = parser.add_mutually_exclusive_group()
level_group.add_argument(
"-O1",
Expand Down
126 changes: 65 additions & 61 deletions optimum/exporters/neuron/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
import os
from argparse import ArgumentParser
from dataclasses import fields
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

Expand All @@ -36,6 +37,10 @@
DIFFUSION_MODEL_VAE_ENCODER_NAME,
ENCODER_NAME,
NEURON_FILE_NAME,
ImageEncoderArguments,
InputShapesArguments,
IPAdapterArguments,
LoRAAdapterArguments,
is_neuron_available,
is_neuronx_available,
is_transformers_neuronx_available,
Expand Down Expand Up @@ -278,6 +283,26 @@ def infer_stable_diffusion_shapes_from_diffusers(
"encoder_hidden_size": encoder_hidden_size,
}

# Image encoder
if getattr(model, "image_encoder", None):
input_shapes["image_encoder"] = {
"batch_size": input_shapes[unet_or_transformer_name]["batch_size"],
"num_channels": model.image_encoder.config.num_channels,
"width": model.image_encoder.config.image_size,
"height": model.image_encoder.config.image_size,
}
# IP-Adapter: add image_embeds as input for unet/transformer
# unet has `ip_adapter_image_embeds` with shape [batch_size, 1, (self.image_encoder.config.image_size//patch_size)**2+1, self.image_encoder.config.hidden_size] as input
if getattr(model.unet.config, "encoder_hid_dim_type", None) == "ip_image_proj":
input_shapes[unet_or_transformer_name]["image_encoder_shapes"] = ImageEncoderArguments(
sequence_length=model.image_encoder.vision_model.embeddings.position_embedding.weight.shape[0],
hidden_size=model.image_encoder.vision_model.embeddings.position_embedding.weight.shape[1],
projection_dim=getattr(model.image_encoder.config, "projection_dim", None),
)

# Format with `InputShapesArguments`
for sub_model_name in input_shapes.keys():
input_shapes[sub_model_name] = InputShapesArguments(**input_shapes[sub_model_name])
return input_shapes


Expand All @@ -294,11 +319,8 @@ def get_submodels_and_neuron_configs(
submodels: Optional[Dict[str, Union[Path, str]]] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
lora_model_ids: Optional[Union[str, List[str]]] = None,
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
lora_scales: Optional[Union[float, List[float]]] = None,
controlnet_ids: Optional[Union[str, List[str]]] = None,
lora_args: Optional[LoRAAdapterArguments] = None,
):
is_encoder_decoder = (
getattr(model.config, "is_encoder_decoder", False) if isinstance(model.config, PretrainedConfig) else False
Expand All @@ -315,11 +337,8 @@ def get_submodels_and_neuron_configs(
dynamic_batch_size=dynamic_batch_size,
submodels=submodels,
output_hidden_states=output_hidden_states,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
controlnet_ids=controlnet_ids,
lora_args=lora_args,
)
elif is_encoder_decoder:
optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states}
Expand All @@ -346,7 +365,10 @@ def get_submodels_and_neuron_configs(
library_name=library_name,
)
input_shapes = check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes)
neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, **input_shapes)
input_shapes = InputShapesArguments(**input_shapes)
neuron_config = neuron_config_constructor(
model.config, dynamic_batch_size=dynamic_batch_size, input_shapes=input_shapes
)
model_name = getattr(model, "name_or_path", None) or model_name_or_path
model_name = model_name.split("/")[-1] if model_name else model.config.model_type
output_model_names = {model_name: "model.neuron"}
Expand All @@ -355,38 +377,15 @@ def get_submodels_and_neuron_configs(
return models_and_neuron_configs, output_model_names


def _normalize_lora_params(lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales):
if isinstance(lora_model_ids, str):
lora_model_ids = [
lora_model_ids,
]
if isinstance(lora_weight_names, str):
lora_weight_names = [
lora_weight_names,
]
if isinstance(lora_adapter_names, str):
lora_adapter_names = [
lora_adapter_names,
]
if isinstance(lora_scales, float):
lora_scales = [
lora_scales,
]
return lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales


def _get_submodels_and_neuron_configs_for_stable_diffusion(
model: Union["PreTrainedModel", "DiffusionPipeline"],
input_shapes: Dict[str, int],
output: Path,
dynamic_batch_size: bool = False,
submodels: Optional[Dict[str, Union[Path, str]]] = None,
output_hidden_states: bool = False,
lora_model_ids: Optional[Union[str, List[str]]] = None,
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
lora_scales: Optional[Union[float, List[float]]] = None,
controlnet_ids: Optional[Union[str, List[str]]] = None,
lora_args: Optional[LoRAAdapterArguments] = None,
):
check_compiler_compatibility_for_stable_diffusion()
model = replace_stable_diffusion_submodels(model, submodels)
Expand All @@ -412,24 +411,19 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
model.save_config(output)

lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales = _normalize_lora_params(
lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales
)
models_and_neuron_configs = get_diffusion_models_for_export(
pipeline=model,
text_encoder_input_shapes=input_shapes["text_encoder"],
unet_input_shapes=input_shapes.get("unet", None),
transformer_input_shapes=input_shapes.get("transformer", None),
vae_encoder_input_shapes=input_shapes["vae_encoder"],
vae_decoder_input_shapes=input_shapes["vae_decoder"],
lora_args=lora_args,
dynamic_batch_size=dynamic_batch_size,
output_hidden_states=output_hidden_states,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
controlnet_ids=controlnet_ids,
controlnet_input_shapes=input_shapes.get("controlnet", None),
image_encoder_input_shapes=input_shapes.get("image_encoder", None),
)
output_model_names = {
DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
Expand All @@ -449,6 +443,8 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
output_model_names[DIFFUSION_MODEL_TRANSFORMER_NAME] = os.path.join(
DIFFUSION_MODEL_TRANSFORMER_NAME, NEURON_FILE_NAME
)
if getattr(model, "image_encoder", None) is not None:
output_model_names["image_encoder"] = os.path.join("image_encoder", NEURON_FILE_NAME)

# ControlNet models
if controlnet_ids:
Expand Down Expand Up @@ -515,13 +511,11 @@ def load_models_and_neuron_configs(
local_files_only: bool,
token: Optional[Union[bool, str]],
submodels: Optional[Dict[str, Union[Path, str]]],
lora_model_ids: Optional[Union[str, List[str]]],
lora_weight_names: Optional[Union[str, List[str]]],
lora_adapter_names: Optional[Union[str, List[str]]],
lora_scales: Optional[Union[float, List[float]]],
torch_dtype: Optional[Union[str, torch.dtype]] = None,
tensor_parallel_size: int = 1,
controlnet_ids: Optional[Union[str, List[str]]] = None,
lora_args: Optional[LoRAAdapterArguments] = None,
ip_adapter_args: Optional[IPAdapterArguments] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
**input_shapes,
Expand All @@ -542,6 +536,14 @@ def load_models_and_neuron_configs(
}
if model is None:
model = TasksManager.get_model_from_task(**model_kwargs)
# Load IP-Adapter if it exists
if ip_adapter_args is not None and not all(
getattr(ip_adapter_args, field.name) is None for field in fields(ip_adapter_args)
):
model.load_ip_adapter(
ip_adapter_args.model_id, subfolder=ip_adapter_args.subfolder, weight_name=ip_adapter_args.weight_name
)
model.set_ip_adapter_scale(scale=ip_adapter_args.scale)

models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs(
model=model,
Expand All @@ -556,11 +558,8 @@ def load_models_and_neuron_configs(
submodels=submodels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
controlnet_ids=controlnet_ids,
lora_args=lora_args,
)

return models_and_neuron_configs, output_model_names
Expand Down Expand Up @@ -592,11 +591,9 @@ def main_export(
output_attentions: bool = False,
output_hidden_states: bool = False,
library_name: Optional[str] = None,
lora_model_ids: Optional[Union[str, List[str]]] = None,
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
lora_scales: Optional[Union[float, List[float]]] = None,
controlnet_ids: Optional[Union[str, List[str]]] = None,
lora_args: Optional[LoRAAdapterArguments] = None,
ip_adapter_args: Optional[IPAdapterArguments] = None,
**input_shapes,
):
output = Path(output)
Expand Down Expand Up @@ -627,20 +624,17 @@ def main_export(
local_files_only=local_files_only,
token=token,
submodels=submodels,
lora_args=lora_args,
ip_adapter_args=ip_adapter_args,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
controlnet_ids=controlnet_ids,
**input_shapes,
)

_, neuron_outputs = export_models(
models_and_neuron_configs=models_and_neuron_configs,
output_dir=output,
torch_dtype=torch_dtype,
disable_neuron_cache=disable_neuron_cache,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
Expand Down Expand Up @@ -752,6 +746,18 @@ def main():
compiler_kwargs = infer_compiler_kwargs(args)
optional_outputs = customize_optional_outputs(args)
optlevel = parse_optlevel(args)
lora_args = LoRAAdapterArguments(
model_ids=getattr(args, "lora_model_ids", None),
weight_names=getattr(args, "lora_weight_names", None),
adapter_names=getattr(args, "lora_adapter_names", None),
scales=getattr(args, "lora_scales", None),
)
ip_adapter_args = IPAdapterArguments(
model_id=getattr(args, "ip_adapter_id", None),
subfolder=getattr(args, "ip_adapter_subfolder", None),
weight_name=getattr(args, "ip_adapter_weight_name", None),
scale=getattr(args, "ip_adapter_scale", None),
)

main_export(
model_name_or_path=args.model,
Expand All @@ -772,11 +778,9 @@ def main():
do_validation=not args.disable_validation,
submodels=submodels,
library_name=library_name,
lora_model_ids=getattr(args, "lora_model_ids", None),
lora_weight_names=getattr(args, "lora_weight_names", None),
lora_adapter_names=getattr(args, "lora_adapter_names", None),
lora_scales=getattr(args, "lora_scales", None),
controlnet_ids=getattr(args, "controlnet_ids", None),
lora_args=lora_args,
ip_adapter_args=ip_adapter_args,
**optional_outputs,
**input_shapes,
)
Expand Down
Loading

0 comments on commit 6786b8c

Please sign in to comment.