fix bug of dual module (setattr, and compatible with DualModule input) #613

Merged · 29 commits · Feb 5, 2024
Changes from 26 commits
Commits
- `d78c84f` support text encoder (marigoold, Jan 30, 2024)
- `4e56a68` refine code (marigoold, Jan 30, 2024)
- `46a6a3f` compatible for prev diffusers version (marigoold, Jan 30, 2024)
- `ffdf6ed` Update __init__.py (marigoold, Jan 30, 2024)
- `9877fb1` refine (marigoold, Jan 30, 2024)
- `9aa6af4` Merge branch 'dev_wy_lora_support_textencoder' of github.com:Oneflow-… (marigoold, Jan 30, 2024)
- `e252c38` update readme (marigoold, Jan 31, 2024)
- `7ede9af` refine doc (marigoold, Jan 31, 2024)
- `ef1eb8c` remove unfuse in fuse func (marigoold, Jan 31, 2024)
- `cead728` Merge branch 'main' into dev_wy_lora_support_textencoder (marigoold, Jan 31, 2024)
- `54903ce` refine (marigoold, Jan 31, 2024)
- `55c65b9` rename (marigoold, Jan 31, 2024)
- `7892a5b` remove out dated lora.py (marigoold, Jan 31, 2024)
- `93388a4` update readme (marigoold, Jan 31, 2024)
- `cb05e3e` refine (marigoold, Jan 31, 2024)
- `a964b23` Update lora.py (marigoold, Jan 31, 2024)
- `46b54c2` fix bug (marigoold, Jan 31, 2024)
- `2cf586e` Merge branch 'dev_wy_lora_support_textencoder' of github.com:Oneflow-… (marigoold, Jan 31, 2024)
- `8743725` refine (marigoold, Jan 31, 2024)
- `cf0452f` Merge branch 'main' into dev_wy_lora_support_textencoder (marigoold, Feb 1, 2024)
- `628a608` dual modulelist setattr fix bug, compatible with DualModule input (marigoold, Feb 1, 2024)
- `c773ac4` remove utils/__init__.py (marigoold, Feb 2, 2024)
- `6009d1a` modify examples (marigoold, Feb 2, 2024)
- `4a43b0b` update doc, and var name (marigoold, Feb 2, 2024)
- `57e8ca1` Merge branch 'dev_wy_lora_support_textencoder' into fix_wy_dualmodule… (marigoold, Feb 2, 2024)
- `bf57ec6` compatible for PEFT (marigoold, Feb 2, 2024)
- `0e7b951` Merge branch 'main' into fix_wy_dualmodulelist_setattr (marigoold, Feb 3, 2024)
- `e2c3991` Merge branch 'main' into fix_wy_dualmodulelist_setattr (strint, Feb 4, 2024)
- `573a1e5` Merge branch 'main' into fix_wy_dualmodulelist_setattr (strint, Feb 5, 2024)
4 changes: 2 additions & 2 deletions examples/text_to_image_sdxl_lora.py
@@ -5,7 +5,7 @@
from onediff.infer_compiler.utils import TensorInplaceAssign

try:
-    from onediffx.utils.lora import load_and_fuse_lora, unfuse_lora
+    from onediffx.lora import load_and_fuse_lora, unfuse_lora
except ImportError:
    raise RuntimeError("OneDiff onediffx is not installed. Please check onediff_diffusers_extensions/README.md to install onediffx.")
@@ -93,7 +93,7 @@

# 4. unfuse_lora can uninstall LoRA weights and restore the weights of UNet
generator = torch.manual_seed(0)
-unfuse_lora(pipe.unet)
+unfuse_lora(pipe)
images_fusion = pipe(
    "masterpiece, best quality, mountain",
    generator=generator,
80 changes: 67 additions & 13 deletions onediff_diffusers_extensions/README.md
@@ -7,7 +7,7 @@ OneDiffX is a OneDiff Extension for HF diffusers. It provides some acceleration
- [DeepCache Speedup](#deepcache-speedup)
  - [Stable Diffusion XL](#run-stable-diffusion-xl-with-onediffx)
  - [Stable Diffusion 1.5](#run-stable-diffusion-15-with-onediffx)
-- [LoRA loading and switching speed up](#lora-loading-and-switching-speed-up)
+- [Fast LoRA loading and switching](#fast-lora-loading-and-switching)
- [Quantization](#quantization)
- [Contact](#contact)

@@ -150,9 +150,42 @@ deepcache_output = pipe(
export_to_video(deepcache_output, "generated.mp4", fps=7)
```

-## LoRA loading and switching speed up
-
-OneDiff provides a faster implementation of loading LoRA, by invoking `onediffx.utils.lora.load_and_fuse_lora` you can load and fuse LoRA to pipeline.
+## Fast LoRA loading and switching
+
+OneDiff provides a more efficient implementation of loading LoRA. By invoking `load_and_fuse_lora` you can load and fuse a LoRA into the pipeline, and by invoking `unfuse_lora` you can restore the weights of the base model.

### API
`onediffx.lora.load_and_fuse_lora(pipeline: LoraLoaderMixin, pretrained_model_name_or_path_or_dict: Union[str, Path, Dict[str, torch.Tensor]], adapter_name: Optional[str] = None, *, lora_scale: float = 1.0, offload_device="cpu", offload_weight="lora", use_cache=False, **kwargs)`:
- pipeline (`LoraLoaderMixin`): The pipeline that will load and fuse the LoRA weights.

- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): Can be either:

    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub.

    - A path to a *directory* containing the model weights saved with [ModelMixin.save_pretrained()](https://huggingface.co/docs/diffusers/v0.25.1/en/api/models/overview#diffusers.ModelMixin.save_pretrained).

    - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

- adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}`, where `i` is the total number of adapters being loaded. **Not supported yet.**

- lora_scale (`float`, defaults to 1.0): Controls how much the LoRA parameters influence the outputs.

- offload_device (`str`, must be one of "cpu" or "cuda"): The device onto which the LoRA or model weights are offloaded.

- offload_weight (`str`, must be one of "lora" or "weight"): The type of weight to offload. If set to "lora", the LoRA weights will be offloaded to `offload_device`; if set to "weight", the weights of the Linear and Conv2d layers will be offloaded.

- use_cache (`bool`, *optional*): Whether to cache the LoRA. If set to True, the loaded LoRA weights are kept in memory for faster subsequent loads.

- kwargs (`dict`, *optional*): See [lora_state_dict()](https://huggingface.co/docs/diffusers/v0.25.1/en/api/loaders/lora#diffusers.loaders.LoraLoaderMixin.lora_state_dict).
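For illustration, here is a minimal sketch of a call that exercises the optional arguments (the directory path and the `weight_name` value are placeholders; `weight_name` is one of the kwargs forwarded to `lora_state_dict()`):

```python
load_and_fuse_lora(
    pipe,
    "path/to/lora",         # a Hub model id, a local directory, or a torch state dict
    lora_scale=0.8,         # scale the LoRA's influence on the outputs
    offload_device="cpu",   # offload to CPU rather than GPU
    offload_weight="lora",  # offload the LoRA weights, not the Linear/Conv2d weights
    use_cache=True,         # keep this LoRA cached in memory for fast re-loading
    weight_name="pytorch_lora_weights.safetensors",  # placeholder file name
)
```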



`onediffx.lora.unfuse_lora(pipeline: LoraLoaderMixin) -> None`:

- pipeline (`LoraLoaderMixin`): The pipeline whose LoRA weights will be unfused, restoring the base model weights.

### Example

```python
import torch
@@ -161,9 +194,7 @@ from onediffx import compile_pipe
from onediffx.lora import load_and_fuse_lora, unfuse_lora

MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
-pipe = DiffusionPipeline.from_pretrained(
-    MODEL_ID, variant="fp16", torch_dtype=torch.float16
-).to("cuda")
+pipe = DiffusionPipeline.from_pretrained(MODEL_ID, variant="fp16", torch_dtype=torch.float16).to("cuda")

LORA_MODEL_ID = "hf-internal-testing/sdxl-1.0-lora"
LORA_FILENAME = "sd_xl_offset_example-lora_1.0.safetensors"
@@ -179,17 +210,40 @@ images_fusion = pipe(
    num_inference_steps=30,
).images[0]
images_fusion.save("test_sdxl_lora.png")

# before loading another LoRA, unfuse the current one to restore
# the base model weights (load_and_fuse_lora also does this implicitly)
unfuse_lora(pipe)
load_and_fuse_lora(pipe, LORA_MODEL_ID, weight_name=LORA_FILENAME, lora_scale=1.0)
```

-We compared different methods of loading LoRA. The comparison of loading LoRA once is as shown in the table below.
-
-| Method | Speed | Inference speed | LoRA loading speed |
-|----------------------------------|-------|------------------|-----------------------|
-| load_lora_weight | 1.10s | low | high |
-| load_lora_weight + fuse_lora | 1.38s | high | low |
-| onediff load_and_fuse_lora | 0.56s | **high** | **high** |
+### Benchmark

We chose five LoRAs to profile the loading and switching speed of three different APIs (a timing sketch follows the table below):

1. `load_lora_weights`, which has high loading speed but low inference speed

2. `load_lora_weights` + `fuse_lora`, which has high inference speed but low loading speed

3. `onediffx.lora.load_and_fuse_lora`, which has both high loading speed and high inference speed

The results are shown below.

| LoRA name | Size | load_lora_weights | load_lora_weights + fuse_lora | **onediffx load_and_fuse_lora** | UNet cnt | TE1 cnt | TE2 cnt | Source |
|------------------------------------------|-------|-------------------|-------------------------------|---------------------------------|----------|---------|---------|-----------------------------------------------|
| SDXL-Emoji-Lora-r4.safetensors | 28M | 1.69 s | 2.34 s | **0.78 s** | 2166 | 216 | 576 | [Link](https://novita.ai/model/SDXL-Emoji-Lora-r4_160282) |
| sdxl_metal_lora.safetensors | 23M | 0.97 s | 1.73 s | **0.19 s** | 1120 | 0 | 0 | |
| simple_drawing_xl_b1-000012.safetensors | 55M | 1.67 s | 2.57 s | **0.77 s** | 2166 | 216 | 576 | [Link](https://civitai.com/models/177820/sdxl-simple-drawing) |
| texta.safetensors | 270M | 1.72 s | 2.86 s | **0.97 s** | 2364 | 0 | 0 | [Link](https://civitai.com/models/221240/texta-generate-text-with-sdxl) |
| watercolor_v1_sdxl_lora.safetensors | 12M | 1.54 s | 2.01 s | **0.35 s** | 1680 | 0 | 0 | |
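These numbers are machine-dependent. As a rough recipe, the three approaches can be timed along the following lines (a minimal sketch reusing `pipe`, `LORA_MODEL_ID`, and `LORA_FILENAME` from the example above; it is not the exact script behind the table):

```python
import time

import torch

def timed(fn):
    # wall-clock timing; synchronize so pending GPU work is counted
    torch.cuda.synchronize()
    start = time.perf_counter()
    fn()
    torch.cuda.synchronize()
    return time.perf_counter() - start

# 1. load_lora_weights only: fast load, but the LoRA stays unfused at inference time
t1 = timed(lambda: pipe.load_lora_weights(LORA_MODEL_ID, weight_name=LORA_FILENAME))
pipe.unload_lora_weights()

# 2. load_lora_weights + fuse_lora: fused weights infer fast, but fusing is slow
def diffusers_load_and_fuse():
    pipe.load_lora_weights(LORA_MODEL_ID, weight_name=LORA_FILENAME)
    pipe.fuse_lora()

t2 = timed(diffusers_load_and_fuse)
pipe.unfuse_lora()
pipe.unload_lora_weights()

# 3. onediffx: load and fuse in one fast step
t3 = timed(lambda: load_and_fuse_lora(pipe, LORA_MODEL_ID, weight_name=LORA_FILENAME))
unfuse_lora(pipe)

print(f"load_lora_weights: {t1:.2f}s, + fuse_lora: {t2:.2f}s, onediffx: {t3:.2f}s")
```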

### Note

1. The OneDiff LoRA extension does not currently support PEFT and requires diffusers version 0.21.0 or later (a minimal version guard is sketched after this list).

2. Diffusers (without PEFT) can load only one LoRA, so onediffx is likewise restricted to a single LoRA at a time. We are working on making onediffx compatible with PEFT, which will enable it to load multiple LoRAs.
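A minimal sketch of the version check implied by the first note (an illustration, not code from onediffx):

```python
import diffusers
from packaging import version

# The onediffx LoRA helpers assume diffusers >= 0.21.0 (see note 1 above).
if version.parse(diffusers.__version__) < version.parse("0.21.0"):
    raise RuntimeError(
        f"onediffx LoRA utilities require diffusers >= 0.21.0, got {diffusers.__version__}"
    )
```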

If you want to unload one LoRA and load another, you only need to call `load_and_fuse_lora` again; there is no need to call `unfuse_lora` manually, because it is called implicitly inside `load_and_fuse_lora`. You can also call `unfuse_lora` manually whenever you want to restore the base model's weights. A switching loop is sketched below.
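Switching between several LoRAs is therefore just repeated calls to `load_and_fuse_lora`; with `use_cache=True`, revisiting a LoRA is served from memory rather than disk (the weight file names below are placeholders):

```python
# cycle through LoRAs; each call implicitly unfuses the previous one
for weight_file in ["lora_a.safetensors", "lora_b.safetensors"]:  # placeholder names
    load_and_fuse_lora(pipe, LORA_MODEL_ID, weight_name=weight_file, use_cache=True)
    image = pipe("masterpiece, best quality, mountain", num_inference_steps=30).images[0]
    image.save(f"test_{weight_file}.png")
```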

## Quantization
