Description
Is your feature request related to a problem? Please describe.
It would be great to be able to load a LoRA into a model compiled with torch.compile.
Describe the solution you'd like.
Being able to call load_lora_weights on a compiled pipe (ideally without triggering recompilation).
Currently, running this code:
import torch
from diffusers import DiffusionPipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to(device)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
# loading the LoRA after compilation is where it fails
pipe.load_lora_weights("multimodalart/flux-tarot-v1")
It errors:
Loading adapter weights from state_dict led to unexpected keys not found in the model: ['single_transformer_blocks.0.attn.to_k.lora_A.default_3.weight', 'single_transformer_blocks.0.attn.to_k.lora_B.default_3.weight', 'single_transformer_blocks.0.attn.to_q.lora_A.default_3.weight', 'single_transformer_blocks.0.attn.to_q.lora_B.default_3.weight',
When a model is compiled, all of its state dict keys seem to gain an _orig_mod. prefix:
odict_keys(['_orig_mod.time_text_embed.timestep_embedder.linear_1.weight',...
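For reference, a rough workaround sketch (untested; it assumes torch.compile's wrapper keeps the original module on the _orig_mod attribute, that re-assigning pipe.transformer is safe, and that adding the adapter may still trigger recompilation on the next call):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

# The compiled wrapper stores the real module on `_orig_mod`; temporarily
# point the pipeline at it so the LoRA keys match the unprefixed parameters.
compiled_transformer = pipe.transformer
pipe.transformer = compiled_transformer._orig_mod
pipe.load_lora_weights("multimodalart/flux-tarot-v1")

# Restore the compiled wrapper; it still references the same underlying module,
# but the new adapter layers will most likely force a recompilation.
pipe.transformer = compiled_transformer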
Describe alternatives you've considered.
An alternative is to fuse the LoRA into the model and then compile; however, this does not allow for hot-swapping LoRAs (as a new pipeline and a new compilation are needed for every LoRA).
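For illustration, a minimal sketch of that alternative using the standard load_lora_weights / fuse_lora flow (a fresh pipeline and compile are needed whenever the LoRA changes):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Fuse the LoRA into the base weights *before* compiling...
pipe.load_lora_weights("multimodalart/flux-tarot-v1")
pipe.fuse_lora()

# ...then compile the transformer with the fused weights baked in.
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

# Switching to a different LoRA now means rebuilding the pipeline and compiling
# again, which is what makes hot swapping impractical with this approach.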
Additional context.
This seems to have been achieved by @chengzeyi, author of the now-paused https://github.com/chengzeyi/stable-fast; however, it appears to be part of the non-open-source FAL optimized inference stack (if you'd like to contribute this upstream, feel free!)