[LoRA] Add LoRA training script #1884
Conversation
The documentation is not available anymore as the PR was closed or merged.
This is great, but should it be part of diffusers? Why not have this as an external library? Maybe this is more of a meta-comment, but imho there is no need for diffusers to be everything. It should be the base that other libraries can build on. To me this seems easier both for the contributors/maintainers of the "advanced" libraries and for diffusers itself, as there's bound to be a difference in development speed/cadence between these new and shiny methods and the core; it won't be pleasant to have to update an amalgamation of code just because one of the new shiny embedded libraries advances. The cloneofsimo/lora repo works very well with diffusers, so wouldn't it be better to do all LoRA-related development there (or, if this implementation is incompatible, in a new repo, so that there are two LoRA-flavored libraries, which I think is preferable to just putting everything in diffusers)?
I would say it is an honor to have my project be an inspiration and become part of the official Hugging Face source code, but I do have a feeling that we are reinventing the wheel here...
Hey @jrd-rocks and @cloneofsimo,

Thanks for your comments - it's super nice to see other repositories such as https://github.com/cloneofsimo/lora building on diffusers. @cloneofsimo, would it be ok if we state you as one of the authors of this script and link to your GitHub repo? (Or would you maybe like to help with this PR to make you an author by commit?) This example script will have a couple of differences:

=> We intend this script rather as a long-term maintained example script of how to use LoRA; we're happy to refer to yours as "the" LoRA training script if you'd like :-)
Actually, the main reason we opened this PR was that the community asked for it here: #1715
Super cool to see this development. However, I'm wondering why it was necessary to create new CrossAttention classes for LoRA? I can't figure out how this differs from how people have been applying @cloneofsimo's repo. In case you don't want to pollute the PR, I've posted the question in another discussion here: cloneofsimo/lora#107. Thanks for all your efforts and any insight(s) you can offer!
Hey @brian6091,

The CrossAttention mechanism was not (just) introduced for LoRA; its main usage is to be able to tweak attention weights at runtime, as explained in #1639.
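For reference, a rough sketch of what swapping attention processors at runtime looks like, which is the hook LoRA uses. This is based on the pattern used around this PR; the import path and class names are assumptions and may differ between diffusers versions:

```python
# Sketch only: build LoRA attention processors for every attention layer of the UNet
# and swap them in at runtime. Import paths/class names may differ between versions.
from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import LoRACrossAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers (attn1) have no cross-attention dimension.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

# Processors are plain modules that can be replaced at any time.
unet.set_attn_processor(lora_attn_procs)
```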
Thanks for the context @patrickvonplaten, I better understand the design decision now. |
Hi @patrickvonplaten, thank you for your kind explanations! I would love it if you referenced it like that. Thanks for the hard work!
examples/lora/train_lora.py
Outdated
optimizer_class = torch.optim.AdamW

# Optimizer creation
params_to_optimize = itertools.chain(*[v.parameters() for v in unet.attn_processors.values()])
Shouldn't this also contain the text encoder parameters if `train_text_encoder` is set to True?
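A hypothetical sketch of what this question is getting at (the `train_text_encoder` flag and the `unet`, `text_encoder`, `args`, and `optimizer_class` names follow the surrounding script and are otherwise assumptions, not the final code):

```python
import itertools

# Optimize the LoRA attention-processor parameters, and optionally the text encoder too.
params_to_optimize = itertools.chain(*[v.parameters() for v in unet.attn_processors.values()])
if args.train_text_encoder:
    params_to_optimize = itertools.chain(params_to_optimize, text_encoder.parameters())

optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
```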
examples/lora/train_lora.py
Outdated
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
pipeline.set_progress_bar_config(disable=True)
sample_dir = "/home/patrick_huggingface_co/lora-tryout/samples"
You probably have noted it already but this needs to be more generic I guess.
for tracker in accelerator.trackers:
    if tracker.name == "wandb":
A cleaner way might be:

if wandb.run is not None:
    ...
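A minimal illustration of that suggestion (not code from the PR; `images` and `args.validation_prompt` are assumed to exist as in the surrounding script):

```python
import wandb

# Log validation images only when a wandb run is active, instead of looping over trackers.
if wandb.run is not None:
    wandb.log(
        {"validation": [wandb.Image(image, caption=args.validation_prompt) for image in images]}
    )
```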
@sayakpaul This comes from a code snippet where I did different things depending on whether the tracker was tensorflow, wandb or something else (there can be different trackers enabled). But yes, if we are only considering the case of wandb, we could maybe simplify it.
@patrickvonplaten the report does not include prompts for the logged images - did you prepare it from a previous run?
@pcuenca @patil-suraj @sayakpaul - I think this is ready for a final review :-)
Looks great to me, I like the `loaders` API! I just have two questions:
- Why is it cast to `weight_dtype`?
- Does generation work when mixed-precision training is enabled?
if global_step >= args.max_train_steps:
    break

if args.validation_prompt is not None and epoch % 10 == 0:
Should we let the user control how often to generate, rather than hardcoding the value here?
Yes, good point - adding `validation_epochs`.
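A hypothetical sketch of that flag (the argument name comes from the comment above; the default and help text are assumptions):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--validation_prompt", type=str, default=None)
parser.add_argument(
    "--validation_epochs",
    type=int,
    default=10,
    help="Run validation generation every X epochs instead of a hardcoded value.",
)
args = parser.parse_args()

# Inside the training loop, the hardcoded `epoch % 10` check would become:
# if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
#     ...  # run the validation pipeline and log the images
```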
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
Maybe use `autocast` here to do generation in fp16. Have you verified this with mixed-precision?
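A sketch of what that suggestion could look like (`pipeline`, `accelerator`, and `args` follow the surrounding script; this is not the final code):

```python
import torch

generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
prompt = args.num_validation_images * [args.validation_prompt]

# Run generation under autocast so inference happens in fp16 when mixed precision is used.
with torch.autocast("cuda"):
    images = pipeline(prompt, num_inference_steps=25, generator=generator).images
```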
Impressive work!
**___Note: When using LoRA we can use a much higher learning rate compared to vanilla dreambooth. Here we use *1e-4* instead of the usual *2e-6*.___**
👍 Perfect
src/diffusers/loaders.py
Outdated
logger = logging.get_logger(__name__)


ATTN_WEIGHT_NAME = "pytorch_attn_procs.bin"
Actually maybe it would make more sense to make the naming more general here:

- ATTN_WEIGHT_NAME = "pytorch_attn_procs.bin"
+ ATTN_WEIGHT_NAME = "embeddings.bin"
So that multiple loaders could be applied on the same file? cc @pcuenca @patil-suraj
I thought about it, but wasn't sure. Another idea would be to make it more specific and descriptive, like `lora_embeddings.bin`, and then use different names for others. Not sure what would be easiest to deal with going forward.
Not sure if `embeddings` is a good name, since these aren't embeddings. I agree with Pedro; maybe make it specific to the procs, for example `lora_layers.bin` or `lora_weights.bin`.
Hmm, not a big fan of making it super specific, in case we want to expand the functionality to more "adapter" layers in the future. E.g. if someone wants to use both LoRA and textual inversion, it'd be nicer to have everything in one file, no?

=> Going for `adapter_weights.bin` now, ok for you?
Alright, @patil-suraj convinced me to go for a LoRA-weights-specific name - this means in the longer run:
- We have different files for different parts of the pipeline
- The user will call multiple loading methods for different parts of the pipeline, e.g.:
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("...")
pipe.unet.load_attn_procs("...")
pipe.load_text_embeddings("...")
But I think that's fine!
This is going to be an enabler. I am telling you!
Update: when running the script in mixed precision, training uses about 6.5 GB of GPU RAM, but inference goes up to about 14 GB. One can simply generate fewer samples during inference or run the pipeline in a loop.
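Not from the PR, but a sketch of the "run the pipeline in a loop" idea, generating one validation image at a time to keep peak inference memory lower (`pipeline`, `args`, and `generator` as in the training script):

```python
images = []
for _ in range(args.num_validation_images):
    # Generating a single image per call keeps the activation memory of the batch small.
    image = pipeline(
        args.validation_prompt, num_inference_steps=25, generator=generator
    ).images[0]
    images.append(image)
```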
14 GB for inference!?!?! I do think this should be part of diffusers.
wandb is a requirement in the current code; that seems like a bug...
@jtoy I could complete a fine-tuning run using a 2080 Ti (11 GB of RAM) :) And yes, I agree that
@pcuenca what args did you use? I used my titan X 1080 with 12 gb and it dies with OOM. I used the example in the README: accelerate launch train_dreambooth_lora.py |
@jtoy Sorry for the confusion, I was referring to a complete fine-tuning using the
Thank you for adding LoRA.
* [Lora] first upload
* add first lora version
* upload
* more
* first training
* up
* correct
* improve
* finish loaders and inference
* up
* up
* fix more
* up
* finish more
* finish more
* up
* up
* change year
* revert year change
* Change lines
* Add cloneofsimo as co-author. Co-authored-by: Simo Ryu <cloneofsimo@gmail.com>
* finish
* fix docs
* Apply suggestions from code review. Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com>
* upload
* finish

Co-authored-by: Simo Ryu <cloneofsimo@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Update:
Training seems to work fine -> see some results here (after 4 min of training on an A100): https://wandb.ai/patrickvonplaten/stable_diffusion_lora/reports/LoRA-training-results--VmlldzozMzI4MTI3?accessToken=d7x29esww3nvbrilo18hyto784w4oep721jiqgophgzdhztytwko1stcscp38gld
Possible API:
The premise of LoRA is to add weights to the model and train only those, so that fine-tuning results in a very small set of portable weights.
Therefore it is important to add a new "LoRA weights loading API" which is currently implemented as follows:
The idea is the following: during training, only the LoRA layers are saved, which for the default rank=4 amount to only around 3 MB: https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin
Those weights can then be downloaded easily from the Hub via a novel `load_lora` loading function, as implemented here: https://github.com/huggingface/diffusers/pull/1884/files#r1069869084
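A minimal sketch of that loading flow, using the `load_attn_procs` call shown earlier in this thread (the base model id, prompt, and default weight filename handling are assumptions and may differ by diffusers version):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Pull the ~3 MB LoRA attention-processor weights from the Hub (repo from the description above)
# and attach them to the UNet's attention layers.
pipe.unet.load_attn_procs("patrickvonplaten/lora")

image = pipe("a photo of a sks dog in a bucket", num_inference_steps=25).images[0]
image.save("lora_sample.png")
```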
Co-authors:
Co-authored-by: https://github.com/cloneofsimo - the first to come up with the idea of using LoRA for Stable Diffusion, in the popular "lora" repo: https://github.com/cloneofsimo/lora