Skip to content

Commit

Permalink
onediffx supports lycoris (#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
marigoold authored Apr 20, 2024
1 parent b6fd9f0 commit b9d0ae8
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 26 deletions.
24 changes: 12 additions & 12 deletions onediff_diffusers_extensions/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,24 @@ Set the LoRA layers of `adapter_name` for the unet and text-encoder(s) with rela
- adapter_names(`str` or `List[str]`): The adapter name(s) of LoRA(s) to be set for the pipeline, must appear in the `adapter_name` parameter of the `load_and_fuse_lora` function, otherwise it will be ignored.
- adapter_weights(`float` or `List[float]`, optional): The weight(s) of the adapter(s). If it is None, it will be set to 1.0.

#### `onediffx.lora.delete_adapters``
#### `onediffx.lora.delete_adapters`

`onediffx.lora.delete_adapters(pipeline: LoraLoaderMixin, adapter_names: Union[List[str], str])`

Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).

- adapter_names (`str` or `List[str]`): The name(s) of the adapter(s) to delete. Can be a single string or a list of strings.

#### `onediffx.lora.update_graph_with_constant_folding_info`

`onediffx.lora.update_graph_with_constant_folding_info(module: torch.nn.Module, info: Dict[str, flow.Tensor] = None)`

Update the weights of the graph after loading LoRA. (If OneDiff has enabled constant folding optimization during compilation, some parameters in the static graph may not be updated correctly after loading LoRA. Invoke this function manually to update the weights of the static graph correctly.)

Check [text_to_image_sdxl_lora.py](./examples/text_to_image_sdxl_lora.py) for more details.

> **Note**: If you are using onediffx instead of diffusers and PEFT to load LoRA, there is no need to call this function, as onediffx will handle all the necessary work.
### Example

```python
Expand Down Expand Up @@ -413,16 +423,6 @@ We tested the performance of `set_adapters`, still using the five LoRA models me

1. OneDiff extensions for LoRA currently support only a limited set of PEFT APIs, and require diffusers version 0.21.0 or later.

2. If your LoRA model only contains the weights of the Linear module, you can directly use OneDiffX without any modifications. But if your LoRA model includes the weights of the Conv module (such as LyCORIS), you need to disable constant folding optimization by above methods (which may cause a performance drop of around 4.4%), otherwise the weights of the Conv module may not be loaded into the model.
- Set the env var `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` to 0
- Set compiler_config.mlir_enable_inference_optimization to 0 before invoking `oneflow_compile` as the code below
```
from onediffx import compiler_config
compiler_config.mlir_enable_inference_optimization = 0
...
pipe.unet = oneflow_compile(pipe.unet)
...
```
### Optimization
- When not using the PEFT backend, diffusers will replace the module corresponding to LoRA with the LoRACompatible module, incurring additional parameter initialization time overhead. In OneDiffX, the LoRA parameters are directly fused into the model, bypassing the step of replacing the module, thereby reducing the time overhead.

Expand Down Expand Up @@ -453,7 +453,7 @@ new_base.unet = compiled_unet
new_base(prompt)
```

> Note: Please make sure that your PyTorch version is **at least 2.1.0**, and set the environment variable `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` to **0**. And the feature is not supported for quantized model.
> Note: This feature is not supported for quantized models.

## Quantization
Expand Down
30 changes: 16 additions & 14 deletions onediff_diffusers_extensions/examples/text_to_image_sdxl_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from onediff.infer_compiler.utils import TensorInplaceAssign

try:
from onediffx.lora import load_and_fuse_lora, unfuse_lora
from onediffx.lora import load_and_fuse_lora, unfuse_lora, update_graph_with_constant_folding_info
except ImportError:
raise RuntimeError(
"OneDiff onediffx is not installed. Please check onediff_diffusers_extensions/README.md to install onediffx."
Expand All @@ -19,7 +19,7 @@
LORA_FILENAME = "sd_xl_offset_example-lora_1.0.safetensors"

pipe.unet = oneflow_compile(pipe.unet)
generator = torch.manual_seed(0)
latents = torch.randn(1, 4, 128, 128, generator=torch.cuda.manual_seed(0), dtype=torch.float16, device="cuda")

# There are three methods to load LoRA into OneDiff compiled model
# 1. pipe.load_lora_weights (Low Performance)
Expand All @@ -33,24 +33,24 @@
pipe.load_lora_weights(LORA_MODEL_ID, weight_name=LORA_FILENAME)
images_fusion = pipe(
"masterpiece, best quality, mountain",
generator=generator,
generator=torch.manual_seed(0),
height=1024,
width=1024,
num_inference_steps=30,
latents=latents,
).images[0]
images_fusion.save("test_sdxl_lora_method1.png")
pipe.unload_lora_weights()


# need to rebuild UNet because method 1 has a different computation graph from the naive UNet
generator = torch.manual_seed(0)
pipe = DiffusionPipeline.from_pretrained(
MODEL_ID, variant="fp16", torch_dtype=torch.float16
).to("cuda")
pipe.unet = oneflow_compile(pipe.unet)
images_fusion = pipe(
"masterpiece, best quality, mountain",
generator=generator,
generator=torch.manual_seed(0),
height=1024,
width=1024,
num_inference_steps=30,
Expand All @@ -59,18 +59,21 @@

# 2. pipe.load_lora_weights + TensorInplaceAssign + pipe.fuse_lora (Deprecated)
# The 'fuse_lora' API is not available in diffusers versions prior to 0.21.0.
generator = torch.manual_seed(0)
pipe.load_lora_weights(LORA_MODEL_ID, weight_name=LORA_FILENAME)
if hasattr(pipe, "fuse_lora"):
# TensorInplaceAssign is DEPRECATED and NOT RECOMMENDED, please use onediff.utils.load_and_fuse_lora
with TensorInplaceAssign(pipe.unet):
pipe.fuse_lora(lora_scale=1.0)
# The function has to be invoked manually because of the constant folding optimization
# If you are using OneDiffx instead of diffusers to load LoRA, the function is not necessary
update_graph_with_constant_folding_info(pipe.unet)
images_fusion = pipe(
"masterpiece, best quality, mountain",
generator=generator,
generator=torch.manual_seed(0),
height=1024,
width=1024,
num_inference_steps=30,
latents=latents,
).images[0]
images_fusion.save("test_sdxl_lora_method2.png")

Expand All @@ -81,27 +84,26 @@

# 3. onediff.utils.load_and_fuse_lora (RECOMMENDED)
# load_and_fuse_lora is equivalent to load_lora_weights + fuse_lora
generator = torch.manual_seed(0)
load_and_fuse_lora(pipe, LORA_MODEL_ID, weight_name=LORA_FILENAME, lora_scale=1.0)
images_fusion = pipe(
"masterpiece, best quality, mountain",
generator=generator,
generator=torch.manual_seed(0),
height=1024,
width=1024,
num_inference_steps=30,
latents=latents,
).images[0]

unfuse_lora(pipe)
images_fusion.save("test_sdxl_lora_method3.png")


# 4. unfuse_lora can uninstall LoRA weights and restore the weights of UNet
generator = torch.manual_seed(0)
unfuse_lora(pipe)
images_fusion = pipe(
"masterpiece, best quality, mountain",
generator=generator,
generator=torch.manual_seed(0),
height=1024,
width=1024,
num_inference_steps=30,
latents=latents,
).images[0]

images_fusion.save("test_sdxl_lora_without_lora.png")
2 changes: 2 additions & 0 deletions onediff_diffusers_extensions/onediffx/lora/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
delete_adapters,
get_active_adapters,
)

from onediff.infer_compiler.utils.param_utils import update_graph_with_constant_folding_info
5 changes: 5 additions & 0 deletions onediff_diffusers_extensions/onediffx/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
else:
is_peft_available = lambda: False

from onediff.infer_compiler.utils.param_utils import update_graph_related_tensor

if version.parse(diffusers.__version__) <= version.parse("0.20.0"):
from diffusers.loaders import PatchedLoraProjection
else:
Expand Down Expand Up @@ -136,6 +138,7 @@ def _set_adapter(self, adapter_names, adapter_weights):
if delta_weight is not None:
fused_weight = self.weight.data.float() + delta_weight
self.weight.data.copy_(fused_weight.to(device=device, dtype=dtype))
update_graph_related_tensor(self)


def _delete_adapter(self, adapter_names):
Expand Down Expand Up @@ -224,6 +227,7 @@ def fuse_lora(
lora_weight = get_delta_weight(self, w_up, w_down, 1.0)
fused_weight = self.weight.data.float() + lora_weight
self.weight.data.copy_(fused_weight.to(device=device, dtype=dtype))
update_graph_related_tensor(self)


def _unfuse_lora(
Expand Down Expand Up @@ -263,6 +267,7 @@ def _unfuse_lora(

if delta_weight is not None:
self.weight.data -= delta_weight
update_graph_related_tensor(self)


# the code is referenced from https://github.com/huggingface/diffusers/blob/ce9825b56bd8a6849e68b9590022e935400659e6/src/diffusers/loaders/lora_conversion_utils.py#L24
Expand Down
2 changes: 2 additions & 0 deletions src/onediff/infer_compiler/utils/param_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def forward_generate_constant_folding_info_hook(module, args, output):
if getattr(module, CONSTANT_FOLDING_INFO_ATTR, None) is not None:
return

logger.info(f"generate constant folding info")
generate_constant_folding_info(module)


Expand All @@ -177,5 +178,6 @@ def forward_pre_check_and_update_state_hook(module, args):
if constant_folding_info is None:
return

logger.info(f"state_dict updated, modify the related weight in graph")
update_graph_with_constant_folding_info(module, constant_folding_info)
setattr(module._torch_module, STATE_UPDATED_ATTR, False)

0 comments on commit b9d0ae8

Please sign in to comment.