From f8484d1d5a4e85bb4838dd29c83cbce3baaa0c1d Mon Sep 17 00:00:00 2001
From: Wang Yi <53533850+marigoold@users.noreply.github.com>
Date: Sat, 27 Jan 2024 00:17:56 +0800
Subject: [PATCH] add cached load_lora_weight (#524)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add a cache for loaded LoRAs on top of diffusers' load_lora_weights, to avoid the time cost of loading the same LoRA from disk repeatedly.

TODO:
- [x] support caching of local LoRA files
- [x] support caching of LoRAs downloaded from the Hub
- [x] support unfuse lora
- [x] support custom offload
- [x] profile diffusers

In the original LoRA loading path, the largest time cost is the parameter initialization of the LoRA modules. This step is not needed for inference, so it is the main optimization opportunity.

This PR adds several ways of using LoRA to examples/text_to_image_sdxl_lora.py:

1. Use load_lora_weights only. This changes the computation path of the Linear forward and therefore changes the computation graph. The upside is that no fuse step is needed, since the LoRA computation is deferred to inference time; the downside is that inference performance drops.
2. Use load_lora_weights plus fuse LoRA. The upside is that inference performance is unchanged; the downside is that loading the LoRA takes some time.
3. Use load_and_fuse_lora, developed in this PR, which minimizes the cost of loading and switching LoRAs while keeping inference performance intact. The idea is to add a cache that keeps a CPU offload of each LoRA, so the next load reads it from memory instead of disk. In addition, the fuse step is rewritten by hand to skip the LoRA module parameter initialization, which saves most of the time.

Inference and loading speed profile (loading a LoRA dict that is already in memory):

```
python3 text_to_image_sdxl_lora.py
Loading pipeline components...: 100%|████████████████████████████████████| 7/7 [00:01<00:00,  5.57it/s]
[1] Elapsed time: 0.9750442989170551 seconds
100%|██████████████████████████████████████████████████████████████████| 30/30 [01:08<00:00,  2.28s/it]
100%|██████████████████████████████████████████████████████████████████| 30/30 [00:04<00:00,  6.26it/s]
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
Loading pipeline components...: 100%|████████████████████████████████████| 7/7 [00:01<00:00,  5.51it/s]
100%|██████████████████████████████████████████████████████████████████| 30/30 [00:39<00:00,  1.32s/it]
[2] Elapsed time: 4.074353616917506 seconds
100%|██████████████████████████████████████████████████████████████████| 30/30 [00:04<00:00,  7.18it/s]
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
[3] Elapsed time: 0.7907805619761348 seconds
100%|██████████████████████████████████████████████████████████████████| 30/30 [00:04<00:00,  7.16it/s]
100%|██████████████████████████████████████████████████████████████████| 30/30 [00:04<00:00,  7.14it/s]
```

The elapsed times for the three methods:

1. 0.9750442989170551 seconds
2. 4.074353616917506 seconds
3. 0.7907805619761348 seconds

LoRA loading speed for the three methods (no inference, loading a LoRA dict):

```
python3 /data/home/wangyi/workspace/temp/test.py
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 5.38it/s]
[1] Elapsed time: 3.8003906158264726 seconds
[2] Elapsed time: 5.7611241028644145 seconds
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
[3] Elapsed time: 2.2499090780038387 seconds
```

The loading times for the three methods:

1. 3.8003906158264726 seconds
2. 5.7611241028644145 seconds
3. 2.2499090780038387 seconds

Profiling the time breakdown shows the biggest costs, from highest to lowest: getattr (a design issue of DualModule), linear fuse, and linear unfuse.

```
   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.258    0.258    1.390    1.390 /data/home/wangyi/workspace/diffusers/src/onediff/utils/lora.py:179(load_and_fuse_lora)
11999/7640   0.016    0.000    0.599    0.000 {built-in method builtins.getattr}
 7996/4359   0.015    0.000    0.583    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:82(__getattr__)
     2322    0.025    0.000    0.500    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:120(__init__)
      722    0.058    0.000    0.322    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/utils/lora.py:30(linear_fuse_lora)
    11788    0.006    0.000    0.279    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:159(__init__)
    11788    0.016    0.000    0.273    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:21(__init__)
  1063466    0.160    0.000    0.160    0.000 {method 'replace' of 'str' objects}
    11788    0.006    0.000    0.145    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:157(get_mixed_dual_module)
    14110    0.136    0.000    0.145    0.000 /home/wangyi/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/module.py:437(__init__)
    11788    0.134    0.000    0.139    0.000 {built-in method builtins.__build_class__}
    23576    0.020    0.000    0.133    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:105(__setattr__)
    25978    0.067    0.000    0.127    0.000 /home/wangyi/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/module.py:1617(__setattr__)
      722    0.036    0.000    0.120    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/utils/lora.py:75(linear_unfuse_lora)
 1446/723    0.002    0.000    0.117    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:303(__getattr__)
```
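The hand-rewritten fuse amounts to an in-place weight update. As a rough sketch of the idea for a Linear layer (illustrative only: the helper name `fuse_linear_lora_sketch` is made up here, and the offload/buffer bookkeeping is omitted; see `linear_fuse_lora` in the diff below for the real implementation):

```python
import torch

def fuse_linear_lora_sketch(
    linear: torch.nn.Linear,
    w_down: torch.Tensor,  # LoRA down weight, shape (rank, in_features)
    w_up: torch.Tensor,    # LoRA up weight, shape (out_features, rank)
    scale: float = 1.0,
) -> None:
    # W <- W + scale * (up @ down), done in place on the existing weight,
    # with no LoRA submodule creation and no parameter re-initialization.
    dtype, device = linear.weight.dtype, linear.weight.device
    lora_weight = scale * torch.mm(w_up.float().to(device), w_down.float().to(device))
    fused = linear.weight.data.float() + lora_weight
    linear.weight.data.copy_(fused.to(dtype))
```

---
 examples/text_to_image_sdxl_lora.py                |  88 +++-
 onediff_diffusers_extensions/README.md             |  41 ++
 .../diffusers_extensions/utils/lora.py             | 452 ++++++++++++++++++
 .../utils/model_inplace_assign.py                  |   3 +
 4 files changed, 567 insertions(+), 17 deletions(-)
 create mode 100644 onediff_diffusers_extensions/diffusers_extensions/utils/lora.py

diff --git a/examples/text_to_image_sdxl_lora.py b/examples/text_to_image_sdxl_lora.py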
index d9f5a39cb..13c86258e 100644
--- a/examples/text_to_image_sdxl_lora.py
+++ b/examples/text_to_image_sdxl_lora.py
@@ -1,11 +1,13 @@
 import torch
 from pathlib import Path
-from huggingface_hub import hf_hub_download
 from diffusers import DiffusionPipeline
-from diffusers.utils import DIFFUSERS_CACHE
 from onediff.infer_compiler import oneflow_compile
 from onediff.infer_compiler.utils import TensorInplaceAssign
+try:
+    from diffusers_extensions.utils.lora import load_and_fuse_lora, unfuse_lora
+except ImportError:
+    raise RuntimeError("OneDiff diffusers_extensions is not installed. Please check onediff_diffusers_extensions/README.md to install diffusers_extensions.")
 
 MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
 pipe = DiffusionPipeline.from_pretrained(
@@ -13,33 +15,85 @@
 ).to("cuda")
 
 LORA_MODEL_ID = "hf-internal-testing/sdxl-1.0-lora"
 LORA_FILENAME = "sd_xl_offset_example-lora_1.0.safetensors"
-lora_file = Path(DIFFUSERS_CACHE) / LORA_FILENAME
-if not lora_file.exists():
-    hf_hub_download(
-        repo_id=LORA_MODEL_ID,
-        filename=LORA_FILENAME,
-        local_dir=DIFFUSERS_CACHE,
-    )
 
 pipe.unet = oneflow_compile(pipe.unet)
-pipe.load_lora_weights(lora_file)
 generator = torch.manual_seed(0)
+
+# There are three methods to load LoRA into a OneDiff compiled model
+# 1. pipe.load_lora_weights (Low Performance)
+# 2. pipe.load_lora_weights + TensorInplaceAssign + pipe.fuse_lora (Deprecated)
+# 3. onediff.utils.load_and_fuse_lora (RECOMMENDED)
+
+
+# 1. pipe.load_lora_weights (Low Performance)
+# Using load_lora_weights without fuse_lora is not recommended:
+# it disrupts the attention optimization, which slows down inference.
+pipe.load_lora_weights(LORA_MODEL_ID, weight_name=LORA_FILENAME)
+images_fusion = pipe(
+    "masterpiece, best quality, mountain",
+    generator=generator,
+    height=1024,
+    width=1024,
+    num_inference_steps=30,
+).images[0]
+images_fusion.save("test_sdxl_lora_method1.png")
+pipe.unload_lora_weights()
+
+
+# need to rebuild the UNet because method 1 produces a different computation graph from the native UNet
+generator = torch.manual_seed(0)
+pipe = DiffusionPipeline.from_pretrained(
+    MODEL_ID, variant="fp16", torch_dtype=torch.float16
+).to("cuda")
+pipe.unet = oneflow_compile(pipe.unet)
+images_fusion = pipe(
+    "masterpiece, best quality, mountain",
+    generator=generator,
+    height=1024,
+    width=1024,
+    num_inference_steps=30,
+).images[0]
+
+
+# 2. pipe.load_lora_weights + TensorInplaceAssign + pipe.fuse_lora (Deprecated)
 # The 'fuse_lora' API is not available in diffuser versions prior to 0.21.0.
+generator = torch.manual_seed(0)
+pipe.load_lora_weights(LORA_MODEL_ID, weight_name=LORA_FILENAME)
 if hasattr(pipe, "fuse_lora"):
+    # TensorInplaceAssign is DEPRECATED and NOT RECOMMENDED, please use onediff.utils.load_and_fuse_lora
     with TensorInplaceAssign(pipe.unet):
         pipe.fuse_lora(lora_scale=1.0)
+    images_fusion = pipe(
+        "masterpiece, best quality, mountain",
+        generator=generator,
+        height=1024,
+        width=1024,
+        num_inference_steps=30,
+    ).images[0]
+    images_fusion.save("test_sdxl_lora_method2.png")
 
-if hasattr(pipe, "unfuse_lora"):
     with TensorInplaceAssign(pipe.unet):
         pipe.unfuse_lora()
+pipe.unload_lora_weights()
 
-# load LoRA twice to for checking result consistency
-pipe.load_lora_weights(lora_file)
-if hasattr(pipe, "fuse_lora"):
-    with TensorInplaceAssign(pipe.unet):
-        pipe.fuse_lora(lora_scale=1.0)
+
+# 3. onediff.utils.load_and_fuse_lora (RECOMMENDED)
+# load_and_fuse_lora is equivalent to load_lora_weights + fuse_lora
+generator = torch.manual_seed(0)
+load_and_fuse_lora(pipe, LORA_MODEL_ID, weight_name=LORA_FILENAME, lora_scale=1.0)
+images_fusion = pipe(
+    "masterpiece, best quality, mountain",
+    generator=generator,
+    height=1024,
+    width=1024,
+    num_inference_steps=30,
+).images[0]
+
+images_fusion.save("test_sdxl_lora_method3.png")
+
+# 4. unfuse_lora removes the LoRA weights and restores the original UNet weights
+generator = torch.manual_seed(0)
+unfuse_lora(pipe.unet)
 images_fusion = pipe(
     "masterpiece, best quality, mountain",
     generator=generator,
@@ -48,4 +102,4 @@
     num_inference_steps=30,
 ).images[0]
 
-images_fusion.save("test_sdxl_lora.png")
+images_fusion.save("test_sdxl_lora_without_lora.png")
diff --git a/onediff_diffusers_extensions/README.md b/onediff_diffusers_extensions/README.md
index 5c90d028b..975f04f8a 100644
--- a/onediff_diffusers_extensions/README.md
+++ b/onediff_diffusers_extensions/README.md
@@ -101,6 +101,47 @@ OneDiff Enterprise offers a quantization method that reduces memory usage, incre
 
 If you possess a OneDiff Enterprise license key, you can access instructions on OneDiff quantization and related models by visiting [Hugginface/siliconflow](https://huggingface.co/siliconflow). Alternatively, you can [contact](#contact) us to inquire about purchasing the OneDiff Enterprise license.
 
+## LoRA loading and switching speed-up
+
+OneDiff provides a faster implementation of LoRA loading. By invoking `diffusers_extensions.utils.lora.load_and_fuse_lora`, you can load and fuse a LoRA into the pipeline.
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+from onediff.infer_compiler import oneflow_compile
+from diffusers_extensions.utils.lora import load_and_fuse_lora, unfuse_lora
+
+MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
+pipe = DiffusionPipeline.from_pretrained(
+    MODEL_ID, variant="fp16", torch_dtype=torch.float16
+).to("cuda")
+
+LORA_MODEL_ID = "hf-internal-testing/sdxl-1.0-lora"
+LORA_FILENAME = "sd_xl_offset_example-lora_1.0.safetensors"
+
+pipe.unet = oneflow_compile(pipe.unet)
+
+# use onediff load_and_fuse_lora
+load_and_fuse_lora(pipe, LORA_MODEL_ID, weight_name=LORA_FILENAME, lora_scale=1.0)
+images_fusion = pipe(
+    "masterpiece, best quality, mountain",
+    height=1024,
+    width=1024,
+    num_inference_steps=30,
+).images[0]
+images_fusion.save("test_sdxl_lora.png")
+```
+
+We compared the different methods of loading LoRA. The time to load a LoRA once with each method is shown in the table below.
+
+| Method                       | Loading time | Inference speed | LoRA loading speed |
+|------------------------------|--------------|-----------------|--------------------|
+| load_lora_weight             | 1.10s        | low             | high               |
+| load_lora_weight + fuse_lora | 1.38s        | high            | low                |
+| onediff load_and_fuse_lora   | 0.56s        | **high**        | **high**           |
+
+If you want to unload a LoRA and then load a new one, you only need to call `load_and_fuse_lora` again. There is no need to call `unfuse_lora` manually, because it is called implicitly inside `load_and_fuse_lora`. You can also call `unfuse_lora` yourself to restore the model's original weights.
+
 ## Contact
 
 For users of OneDiff Community, please visit [GitHub Issues](https://github.com/siliconflow/onediff/issues) for bug reports and feature requests.
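To make the switching workflow described in the README above concrete, a minimal sketch follows (the second repo id and weight file are hypothetical placeholders, not files shipped with this PR; `use_cache=True` enables the in-memory LRU cache added by this patch):

```python
import torch
from diffusers import DiffusionPipeline
from onediff.infer_compiler import oneflow_compile
from diffusers_extensions.utils.lora import load_and_fuse_lora, unfuse_lora

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16
).to("cuda")
pipe.unet = oneflow_compile(pipe.unet)

# First load reads the LoRA from disk and keeps its state dict in the cache.
load_and_fuse_lora(
    pipe,
    "hf-internal-testing/sdxl-1.0-lora",
    weight_name="sd_xl_offset_example-lora_1.0.safetensors",
    lora_scale=1.0,
    use_cache=True,
)

# Switching: just call load_and_fuse_lora again; the previous LoRA is
# unfused implicitly before the new one is fused, and a cached LoRA is
# served from memory instead of disk.
load_and_fuse_lora(
    pipe,
    "some-user/another-sdxl-lora",           # hypothetical repo id
    weight_name="another_lora.safetensors",  # hypothetical file name
    lora_scale=1.0,
    use_cache=True,
)

# Optionally restore the original UNet weights.
unfuse_lora(pipe.unet)
```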
diff --git a/onediff_diffusers_extensions/diffusers_extensions/utils/lora.py b/onediff_diffusers_extensions/diffusers_extensions/utils/lora.py
new file mode 100644
index 000000000..4897ec945
--- /dev/null
+++ b/onediff_diffusers_extensions/diffusers_extensions/utils/lora.py
@@ -0,0 +1,452 @@
+from pathlib import Path
+from typing import Optional, Union, Dict, Tuple
+from collections import OrderedDict, defaultdict
+
+import torch
+
+from onediff.infer_compiler.utils.log_utils import logger
+from onediff.infer_compiler.with_oneflow_compile import DualModule
+
+from diffusers.loaders.lora import LoraLoaderMixin
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from diffusers.utils import is_accelerate_available
+
+if is_accelerate_available():
+    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+
+
+USE_PEFT_BACKEND = False
+
+
+def offload_tensor(tensor, device):
+    cur_device = tensor.device
+    if cur_device == device:
+        return tensor.clone()
+    else:
+        return tensor.to(device)
+
+
+def linear_fuse_lora(
+    self: torch.nn.Linear,
+    state_dict: Dict[str, torch.Tensor],
+    lora_scale: float = 1.0,
+    alpha: float = None,
+    rank: float = None,
+    *,
+    offload_device="cpu",
+    offload_weight="lora",
+):
+    assert isinstance(self, torch.nn.Linear)
+    if isinstance(self, DualModule):
+        self = self._torch_module
+
+    linear_unfuse_lora(self)
+    dtype, device = self.weight.data.dtype, self.weight.data.device
+
+    w_down = state_dict["lora.down.weight"].float().to(device)
+    w_up = state_dict["lora.up.weight"].float().to(device)
+
+    if alpha is not None:
+        w_up = w_up * (alpha / rank * lora_scale)
+    else:
+        # apply lora_scale even when no network alpha is given,
+        # matching diffusers' fuse_lora behavior
+        w_up = w_up * lora_scale
+
+    if offload_weight == "lora":
+        self.register_buffer("_lora_up", offload_tensor(w_up, offload_device))
+        self.register_buffer(
+            "_lora_down", offload_tensor(state_dict["lora.down.weight"], offload_device)
+        )
+        self._lora_scale = lora_scale
+
+    elif offload_weight == "weight":
+        self.register_buffer(
+            "_lora_orig_weight", offload_tensor(self.weight.data, offload_device)
+        )
+
+    else:
+        raise ValueError(
+            f"[OneDiff linear_fuse_lora] Invalid offload weight: {offload_weight}"
+        )
+
+    lora_weight = torch.bmm(w_up[None, :], w_down[None, :])[0]
+    fused_weight = self.weight.data.float() + lora_weight
+    self.weight.data.copy_(fused_weight.to(device=device, dtype=dtype))
+
+
+def linear_unfuse_lora(self: torch.nn.Linear):
+    assert isinstance(self, torch.nn.Linear)
+
+    fused_weight = self.weight.data
+    dtype, device = fused_weight.dtype, fused_weight.device
+
+    if (
+        "_lora_orig_weight" in self._buffers
+        and self.get_buffer("_lora_orig_weight") is not None
+    ):
+        unfused_weight = self._lora_orig_weight
+        self._lora_orig_weight = None
+
+    elif "_lora_up" in self._buffers and self.get_buffer("_lora_up") is not None:
+        w_up = self.get_buffer("_lora_up").to(device=device).float()
+        w_down = self.get_buffer("_lora_down").to(device).float()
+
+        unfused_weight = self.weight.data.float() - (
+            torch.bmm(w_up[None, :], w_down[None, :])[0]
+        )
+        self._lora_up = None
+        self._lora_down = None
+        self._lora_scale = None
+
+    else:
+        return
+
+    self.weight.data.copy_(unfused_weight.to(device=device, dtype=dtype))
+
+
+def conv_fuse_lora(
+    self: torch.nn.Conv2d,
+    state_dict: Dict[str, torch.Tensor],
+    lora_scale: float = 1.0,
+    alpha: float = None,
+    rank: float = None,
+    *,
+    offload_device="cpu",
+    offload_weight="lora",
+) -> None:
+    assert isinstance(self, torch.nn.Conv2d)
+    if isinstance(self, DualModule):
+        self = self._torch_module
+    conv_unfuse_lora(self)
+    dtype, device = self.weight.data.dtype, self.weight.data.device
+
+    w_down = state_dict["lora.down.weight"].float().to(device)
+    w_up = state_dict["lora.up.weight"].float().to(device)
+
+    if alpha is not None:
+        w_up = w_up * (alpha / rank * lora_scale)
+    else:
+        # apply lora_scale even when no network alpha is given,
+        # matching diffusers' fuse_lora behavior
+        w_up = w_up * lora_scale
+
+    if offload_weight == "lora":
+        self.register_buffer("_lora_up", offload_tensor(w_up, offload_device))
+        self.register_buffer(
+            "_lora_down", offload_tensor(state_dict["lora.down.weight"], offload_device)
+        )
+        self._lora_scale = lora_scale
+    elif offload_weight == "weight":
+        self.register_buffer(
+            "_lora_orig_weight", offload_tensor(self.weight.data, offload_device)
+        )
+    else:
+        raise ValueError(
+            f"[OneDiff conv_fuse_lora] Invalid offload weight: {offload_weight}"
+        )
+
+    lora_weight = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
+    lora_weight = lora_weight.reshape((self.weight.shape))
+
+    fused_weight = self.weight.data.float() + lora_weight
+    self.weight.data.copy_(fused_weight.to(device=device, dtype=dtype))
+
+
+def conv_unfuse_lora(self: torch.nn.Conv2d):
+    assert isinstance(self, torch.nn.Conv2d)
+
+    fused_weight = self.weight.data
+    dtype, device = fused_weight.dtype, fused_weight.device
+
+    if (
+        "_lora_orig_weight" in self._buffers
+        and self.get_buffer("_lora_orig_weight") is not None
+    ):
+        unfused_weight = self._lora_orig_weight
+        self._lora_orig_weight = None
+
+    elif "_lora_up" in self._buffers and self.get_buffer("_lora_up") is not None:
+        w_up = self._lora_up.to(device=device).float()
+        w_down = self._lora_down.to(device).float()
+
+        fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
+        fusion = fusion.reshape((fused_weight.shape))
+        unfused_weight = fused_weight.float() - fusion
+
+        self._lora_up = None
+        self._lora_down = None
+        self._lora_scale = None
+
+    else:
+        return
+
+    self.weight.data.copy_(unfused_weight.to(device=device, dtype=dtype))
+
+
+def load_and_fuse_lora(
+    pipeline: LoraLoaderMixin,
+    pretrained_model_name_or_path_or_dict: Union[str, Path, Dict[str, torch.Tensor]],
+    adapter_name: Optional[str] = None,
+    *,
+    lora_scale: float = 1.0,
+    offload_device="cpu",
+    offload_weight="lora",
+    use_cache=False,
+    **kwargs,
+) -> None:
+    self = pipeline
+    if adapter_name is not None:
+        raise ValueError(
+            f"[OneDiff load_and_fuse_lora] adapter_name != None is not supported"
+        )
+
+    if use_cache:
+        state_dict, network_alphas = load_state_dict_cached(
+            pretrained_model_name_or_path_or_dict,
+            unet_config=self.unet.config,
+            **kwargs,
+        )
+    else:
+        state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
+            pretrained_model_name_or_path_or_dict,
+            unet_config=self.unet.config,
+            **kwargs,
+        )
+
+    is_correct_format = all("lora" in key for key in state_dict.keys())
+    if not is_correct_format:
+        raise ValueError("[OneDiff load_and_fuse_lora] Invalid LoRA checkpoint.")
+
+    # load lora into unet
+    keys = list(state_dict.keys())
+    cls = type(self)
+
+    if all(
+        key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name)
+        for key in keys
+    ):
+        # Load the layers corresponding to UNet.
+        logger.info(f"Loading {cls.unet_name}.")
+
+        unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
+        state_dict = {
+            k.replace(f"{cls.unet_name}.", ""): v
+            for k, v in state_dict.items()
+            if k in unet_keys
+        }
+
+        if network_alphas is not None:
+            alpha_keys = [
+                k for k in network_alphas.keys() if k.startswith(cls.unet_name)
+            ]
+            network_alphas = {
+                k.replace(f"{cls.unet_name}.", ""): v
+                for k, v in network_alphas.items()
+                if k in alpha_keys
+            }
+
+    else:
+        # Otherwise, we're dealing with the old format. This means the `state_dict` should only
+        # contain the module names of the `unet` as its keys WITHOUT any prefix.
+        warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
+        logger.warn(warn_message)
+
+    # unet.load_attn
+    _pipeline = kwargs.pop("_pipeline", None)
+
+    is_network_alphas_none = network_alphas is None
+
+    is_lora = (
+        all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
+        and not USE_PEFT_BACKEND
+    )
+    is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
+    if is_custom_diffusion:
+        raise ValueError(
+            "[OneDiff load_and_fuse_lora] custom diffusion is not supported now."
+        )
+
+    if is_lora:
+        # correct keys
+        state_dict, network_alphas = self.unet.convert_state_dict_legacy_attn_format(
+            state_dict, network_alphas
+        )
+
+        if network_alphas is not None:
+            network_alphas_keys = list(network_alphas.keys())
+            used_network_alphas_keys = set()
+
+        lora_grouped_dict = defaultdict(dict)
+        mapped_network_alphas = {}
+
+        all_keys = list(state_dict.keys())
+        for key in all_keys:
+            value = state_dict.pop(key)
+            attn_processor_key, sub_key = (
+                ".".join(key.split(".")[:-3]),
+                ".".join(key.split(".")[-3:]),
+            )
+            lora_grouped_dict[attn_processor_key][sub_key] = value
+
+            # Create another `mapped_network_alphas` dictionary so that we can properly map them.
+            if network_alphas is not None:
+                for k in network_alphas_keys:
+                    if k.replace(".alpha", "") in key:
+                        mapped_network_alphas.update(
+                            {attn_processor_key: network_alphas.get(k)}
+                        )
+                        used_network_alphas_keys.add(k)
+
+        if not is_network_alphas_none:
+            if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
+                raise ValueError(
+                    f"[OneDiff load_and_fuse_lora] The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+                )
+
+        if len(state_dict) > 0:
+            raise ValueError(
+                f"[OneDiff load_and_fuse_lora] The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+            )
+
+        for key, value_dict in lora_grouped_dict.items():
+            attn_processor = self.unet
+            for sub_key in key.split("."):
+                attn_processor = getattr(attn_processor, sub_key)
+
+            # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
+            # or add_{k,v,q,out_proj}_proj_lora layers.
+            rank = value_dict["lora.down.weight"].shape[0]
+
+            if isinstance(attn_processor, LoRACompatibleConv):
+                conv_fuse_lora(
+                    attn_processor,
+                    value_dict,
+                    lora_scale,
+                    mapped_network_alphas.get(key),
+                    rank,
+                    offload_device=offload_device,
+                    offload_weight=offload_weight,
+                )
+            elif isinstance(attn_processor, LoRACompatibleLinear):
+                linear_fuse_lora(
+                    attn_processor,
+                    value_dict,
+                    lora_scale,
+                    mapped_network_alphas.get(key),
+                    rank,
+                    offload_device=offload_device,
+                    offload_weight=offload_weight,
+                )
+            else:
+                raise ValueError(
+                    f"[OneDiff load_and_fuse_lora] Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module."
+                )
+    else:
+        raise ValueError(
+            f"[OneDiff load_and_fuse_lora] {pretrained_model_name_or_path_or_dict} does not seem to be in the correct format expected by LoRA training."
+        )
+
+    is_model_cpu_offload = False
+    is_sequential_cpu_offload = False
+
+    if not USE_PEFT_BACKEND:
+        if _pipeline is not None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, torch.nn.Module) and hasattr(
+                    component, "_hf_hook"
+                ):
+                    is_model_cpu_offload = isinstance(
+                        getattr(component, "_hf_hook"), CpuOffload
+                    )
+                    is_sequential_cpu_offload = isinstance(
+                        getattr(component, "_hf_hook"), AlignDevicesHook
+                    )
+
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    remove_hook_from_module(
+                        component, recurse=is_sequential_cpu_offload
+                    )
+
+        # self.to(dtype=self.dtype, device=self.device)
+
+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />
+
+    # load lora weights
+    text_encoder_state_dict = {
+        k: v for k, v in state_dict.items() if "text_encoder." in k
+    }
+    if len(text_encoder_state_dict) > 0:
+        self.load_lora_into_text_encoder(
+            text_encoder_state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder,
+            prefix="text_encoder",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+        )
+
+    text_encoder_2_state_dict = {
+        k: v for k, v in state_dict.items() if "text_encoder_2." in k
+    }
+    if len(text_encoder_2_state_dict) > 0:
+        self.load_lora_into_text_encoder(
+            text_encoder_2_state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder_2,
+            prefix="text_encoder_2",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+        )
+
+
+def unfuse_lora(self: torch.nn.Module):
+    def _unfuse_lora(m: torch.nn.Module):
+        if isinstance(m, torch.nn.Linear):
+            linear_unfuse_lora(m)
+        elif isinstance(m, torch.nn.Conv2d):
+            conv_unfuse_lora(m)
+
+    self.apply(_unfuse_lora)
+
+
+class LRUCacheDict(OrderedDict):
+    def __init__(self, capacity):
+        super().__init__()
+        self.capacity = capacity
+
+    def __getitem__(self, key):
+        value = super().__getitem__(key)
+        self.move_to_end(key)
+        return value
+
+    def __setitem__(self, key, value):
+        # refresh an existing key instead of evicting for it; only evict
+        # the least-recently-used entry when inserting a new key
+        if key in self:
+            self.move_to_end(key)
+        elif len(self) >= self.capacity:
+            oldest_key = next(iter(self))
+            del self[oldest_key]
+        super().__setitem__(key, value)
+
+
+def load_state_dict_cached(
+    lora: Union[str, Path, Dict[str, torch.Tensor]], **kwargs,
+) -> Tuple[Dict, Dict]:
+    assert isinstance(lora, (str, Path, dict))
+    if isinstance(lora, dict):
+        state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(lora, **kwargs)
+        return state_dict, network_alphas
+
+    global CachedLoRAs
+    weight_name = kwargs.get("weight_name", None)
+
+    lora_name = str(lora) + (f"/{weight_name}" if weight_name else "")
+    if lora_name in CachedLoRAs:
+        logger.debug(f"[OneDiff Cached LoRA] get cached lora of name: {str(lora_name)}")
+        return CachedLoRAs[lora_name]
+
+    state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(lora, **kwargs,)
+    CachedLoRAs[lora_name] = (state_dict, network_alphas)
+    logger.debug(f"[OneDiff Cached LoRA] create cached lora of name: {str(lora_name)}")
+    return state_dict, network_alphas
+
+
+CachedLoRAs = LRUCacheDict(100)
diff --git a/src/onediff/infer_compiler/utils/model_inplace_assign.py b/src/onediff/infer_compiler/utils/model_inplace_assign.py
index 24c32ef24..def3df809 100644
--- a/src/onediff/infer_compiler/utils/model_inplace_assign.py
+++ b/src/onediff/infer_compiler/utils/model_inplace_assign.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Union, List
 from collections import defaultdict
 import torch
@@ -38,6 +39,8 @@ class TensorInplaceAssign:
     (True, False)
     """
     def __init__(self, *modules: List[Union[torch.nn.Module, DeployableModule]]) -> None:
+        warnings.warn("The class TensorInplaceAssign is deprecated and will be removed soon. \
+            If you are using _fuse_lora in TensorInplaceAssign, please check `onediff.utils.load_and_fuse_lora`")
        self.modules = set()
        for module in modules:
            if isinstance(module, torch.nn.Module):
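A closing note on the LRUCacheDict defined in diffusers_extensions/utils/lora.py above: it holds at most `capacity` entries and evicts the least-recently-used one when a new key is inserted. A small illustration (not part of the patch; the string values stand in for cached state dicts):

```python
cache = LRUCacheDict(capacity=2)
cache["lora_a"] = "state_dict_a"
cache["lora_b"] = "state_dict_b"

_ = cache["lora_a"]               # reading refreshes lora_a's recency
cache["lora_c"] = "state_dict_c"  # inserting a new key evicts lora_b

assert list(cache.keys()) == ["lora_a", "lora_c"]
```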