
[Community Pipeline] Checkpoint merging #877

Closed
osanseviero opened this issue Oct 17, 2022 · 20 comments

Comments

@osanseviero (Contributor)

Intro

Community Pipelines were introduced in diffusers==0.4.0 to let the community quickly add, integrate, and share their custom pipelines on top of diffusers.

You can find a guide about Community Pipelines here. You can also find all the community examples under examples/community/. If you have questions about the Community Pipelines feature, please head to the parent issue.

Idea: Checkpoint Merging

This pipeline aims to merge two checkpoints into one via interpolation.

@osanseviero (Contributor, Author)

@apolinario any resources you would suggest for this pipeline?

@apolinario (Collaborator)

Yes, checkpoint merging is an idea implemented in the AUTOMATIC1111/stable-diffusion-webui repo. It interpolates the weights of models (fine-tuned variants or different versions of a model), potentially creating cool results:
[image attachment]

This is how it is implemented in the original repo
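For intuition, the interpolation modes such mergers expose can be sketched in plain Python (the function names here are illustrative, not the webui's actual API):

```python
import math

# Illustrative stand-ins for the interpolation modes an A1111-style
# merger offers; they operate on single weights for clarity.

def weighted_sum(a: float, b: float, alpha: float) -> float:
    """Linear interpolation: alpha=0 keeps model A, alpha=1 keeps model B."""
    return (1.0 - alpha) * a + alpha * b

def sigmoid_blend(a: float, b: float, alpha: float) -> float:
    """Ease the blend by remapping alpha through a sigmoid, renormalized to [0, 1]."""
    s = 1.0 / (1.0 + math.exp(-8.0 * (alpha - 0.5)))
    lo = 1.0 / (1.0 + math.exp(4.0))   # sigmoid value at alpha = 0
    hi = 1.0 / (1.0 + math.exp(-4.0))  # sigmoid value at alpha = 1
    return weighted_sum(a, b, (s - lo) / (hi - lo))
```

In a real merger these would be applied element-wise to every tensor shared by the two checkpoints.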

@Abhinay1997 (Contributor)

Would the following be a good start?

# STEP 1: Verify that the checkpoints have the same dimensions/modules etc.
# STEP 2: Find the mergeable modules from both checkpoints (vae, unet, safety checker, etc.).
# STEP 3: For each mergeable component, interpolate the component weights and update them.
# STEP 4: Return a pipeline with the merged weights.

So you would run it like this:

pipe = CheckpointMergerPipeline.from_pretrained(chkpt0="sample/checkpoint-1", chkpt1="sample/checkpoint-2", alpha=0.2, interp="sigmoid")

pipe.save_pretrained("sample/merged-checkpoint")

prompt = "A cat riding a skateboard in an 18th century street at night, moon in the background"

pipe.to("cuda")
pipe(prompt).images[0]

What do you think? @apolinario @osanseviero

@patrickvonplaten (Contributor)

That seems sensible to me

@vvvm23 (Contributor) commented Oct 25, 2022

Hey @Abhinay1997, are you still working on this? If not, @patrickvonplaten may I please give this a shot? 🙂

@Abhinay1997 (Contributor)

Hi @vvvm23, I'm currently working on a solution. Give me 24 hrs to make a PR. I'll let you know if my approach fails, and maybe you can pick it up?

@vvvm23 (Contributor) commented Oct 25, 2022

No worries! Please take your time!

@Abhinay1997 (Contributor)

@patrickvonplaten, @apolinario

I ran into the following issue and wanted your thoughts:

When building a custom pipeline that inherits from DiffusionPipeline, checkpoint merging has to be done like this:

pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline = "checkpoint_merger_pipeline")
## Merge v1.4 weights with v1.2 weights of stable diffusion.
pipe.merge("CompVis/stable-diffusion-v1-2")

The problem here is that the custom_pipeline is loaded via the from_pretrained method of DiffusionPipeline. If we want any kind of checkpoints to be mergeable, we need to keep the modules dynamic. However, the from_pretrained method expects the custom_pipeline class to declare its kwargs in advance, and otherwise returns an empty pipeline.

One solution is to pass the original checkpoint again in the merge method, like:

pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline = "checkpoint_merger_pipeline")
## Merge v1.4 weights with v1.2 weights of stable diffusion.
## SEE EXTRA ARGS HERE
pipe.merge("CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-2")

But then the from_pretrained call would be redundant, serving only to return an instance of the custom_pipeline class, which is confusing.
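One hypothetical way around the redundancy is for the pipeline object to remember the checkpoint it was loaded from, so merge only needs the additional ones. A toy sketch on plain dicts (class and method names are assumptions, not the actual diffusers API):

```python
class MergerSketch:
    """Toy stand-in for the custom pipeline: it remembers its base checkpoint."""

    def __init__(self, base_id: str, base_weights: dict):
        self.base_id = base_id
        self.weights = base_weights

    def merge(self, other_weights: dict, alpha: float = 0.5) -> dict:
        # The base checkpoint was captured at load time, so callers
        # do not have to pass it again here.
        self.weights = {
            k: (1.0 - alpha) * v + alpha * other_weights[k]
            for k, v in self.weights.items()
        }
        return self.weights
```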

What would you suggest?

@patrickvonplaten (Contributor)

Really up to you, as it's a community pipeline :-) In general we can't assume weights can be merged at a low level, so I'm also not 100% sure of the usefulness here. @apolinario, do you have a good idea?

@vvvm23 (Contributor) commented Oct 27, 2022

Just my random thoughts, but why make this a pipeline instead of something functional? For example, a function with args module, ckpt1, ckpt2, alpha?
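A functional version along those lines might look like this (a sketch on plain dicts of floats; the signature follows the suggestion above and is not an existing API):

```python
def merge(module: str, ckpt_a: dict, ckpt_b: dict, alpha: float) -> dict:
    """Interpolate one named module's weights between two checkpoints."""
    sd_a, sd_b = ckpt_a[module], ckpt_b[module]
    if sd_a.keys() != sd_b.keys():
        raise ValueError(f"{module!r} weights are not compatible")
    return {k: (1.0 - alpha) * sd_a[k] + alpha * sd_b[k] for k in sd_a}
```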

@Abhinay1997 (Contributor)

Update: checkpoint validation and the individual interpolation steps are complete. The only thing left is returning a DiffusionPipeline object with the updated weights.

Should be done by the weekend.

@Abhinay1997 (Contributor)

Hey team, sorry it took a while, but here's the Colab notebook with my code. I need your help figuring out how to cut down on memory usage: it currently crashes for the scenario in the notebook (12 GB RAM). I'll try it out on Kaggle (16 GB RAM) and see if it's any better.

@Abhinay1997 (Contributor)

Success! It runs on Kaggle, using around 13 GB RAM to merge Stable Diffusion and Waifu Diffusion. I'll make a PR over the weekend. I've made the notebook public if anyone's interested.

https://www.kaggle.com/abhinaydevarinti/checkpoint-merging-huggingface-diffusers

@patrickvonplaten (Contributor)

Thanks @Abhinay1997

@brucethemoose

Could more "advanced" merging be supported by similar pipelines?

Namely this, which produces better results when merging models with different "base" SD ancestors: https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion

And this, which allows for specifying different weights for different blocks (with the benefits described in the linked blog post): https://github.com/bbc-mc/sdweb-merge-block-weighted-gui
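The second idea, block-weighted merging, boils down to picking a different alpha per parameter based on which block it belongs to. A minimal sketch on plain dicts (the prefix strings mimic UNet parameter keys; none of this is the linked project's actual code):

```python
def block_weighted_merge(sd_a: dict, sd_b: dict,
                         block_alphas: dict, default_alpha: float = 0.5) -> dict:
    """Interpolate each parameter with the alpha assigned to its block prefix."""
    merged = {}
    for key in sd_a:
        alpha = default_alpha
        for prefix, block_alpha in block_alphas.items():
            if key.startswith(prefix):
                alpha = block_alpha
                break
        merged[key] = (1.0 - alpha) * sd_a[key] + alpha * sd_b[key]
    return merged
```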

@patrickvonplaten (Contributor)

Gently pinging @Abhinay1997 here :-)

@Abhinay1997 (Contributor)

@brucethemoose, I briefly went through the post you mentioned. From what I gather, this method permutes the weights and finds an optimal matching between models for merging. From one of the repos linked above, I saw that the matching differs depending on the network architecture. I might be wrong, though; let me look into it in more detail.

As for now, this pipeline is a simple interpolation: it blindly merges matched state_dicts as long as they are compatible and skips incompatible ones. Frankly, I think that approach deserves to be a standalone pipeline of its own. Let me go over it in detail and get back to you. :)

@damian0815 (Contributor) commented Feb 19, 2023

@Abhinay1997 @brucethemoose I've implemented block-weighted merging in grate, based on a modified version of checkpoint_merger. Check it out here: https://github.com/damian0815/grate/blob/main/src/sdgrate/checkpoint_merger_mbw.py (or just pip install sdgrate and then run grate). Maybe my changes could be folded into yours, @Abhinay1997?

@Abhinay1997 (Contributor) commented Feb 19, 2023

Damian,

Your implementation of block-weighted merging is super cool! My motivation with CheckpointMergerPipeline was the hope that it would work as a general-purpose merger for all modules of the passed checkpoints. In line with this, I think we could make the module check (UNet2DConditionModel) dynamic (another argument to the merge method) and make block_weights a nested dict instead.

Of course, these are just my thoughts. I'd like to know what you think.

@damian0815 (Contributor)

@Abhinay1997 I made a PR: #2422
