Re-basin and Stable Diffusion Tensor Flow weights #5

Open
ogkalu2 opened this issue Nov 4, 2022 · 35 comments

Comments

@ogkalu2

ogkalu2 commented Nov 4, 2022

Hi Samuel. Thank you for being willing to look at this. Basically, I'm trying to see if it is possible to use your method to merge Stable Diffusion models I've fine-tuned with DreamBooth.

The first hurdle, of course, is that your implementation is not yet compatible with PyTorch, as far as I know. But the PyTorch weights can be converted to TensorFlow weights; this has been done before. I don't mind doing this if I have to.

The second hurdle will be getting your implementation to work on the TF models. I'm not sure how feasible this all is, or how I would use your code on the SD TensorFlow weights.

@ogkalu2 ogkalu2 closed this as completed Nov 10, 2022
@samuela
Owner

samuela commented Nov 11, 2022

Hi @ogkalu2! The exact framework that the model runs in is not super important. For example, we have used our JAX code to align the weights of two PyTorch models in the past. The only important part is that you can load the weights into Python/JAX and that you have a correct PermutationSpec for your model.

In general writing down the PermutationSpec will be the more challenging step. We are currently working on a pure-PyTorch version, including a "tracer" that can automatically generate the PermutationSpec for you but this is not ready for release just yet.

Would be very cool to see if this works on StableDiffusion models! Do let me know if you get it working!
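To make the PermutationSpec idea concrete: a spec only has to say, for every weight key, which named permutation (or None) acts on each axis; the inverse map from permutation name to (key, axis) pairs can be derived from that. The pure-Python sketch below assumes PyTorch's [out_features, in_features] layout for Linear weights. The class, function and fc1/fc2 key names are illustrative, not the repository's exact API.

```python
from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Tuple


class PermutationSpec(NamedTuple):
    # permutation name -> every (weight key, axis) pair it acts on
    perm_to_axes: Dict[str, List[Tuple[str, int]]]
    # weight key -> one entry per axis: a permutation name, or None to leave that axis alone
    axes_to_perm: Dict[str, Tuple[Optional[str], ...]]


def spec_from_axes_to_perm(axes_to_perm):
    """Derive perm_to_axes from axes_to_perm so the two views always agree."""
    perm_to_axes = defaultdict(list)
    for key, perms in axes_to_perm.items():
        for axis, perm in enumerate(perms):
            if perm is not None:
                perm_to_axes[perm].append((key, axis))
    return PermutationSpec(dict(perm_to_axes), axes_to_perm)


# Hypothetical 2-layer MLP. Linear weights are [out, in] in PyTorch.
spec = spec_from_axes_to_perm({
    "fc1.weight": ("P_hidden", None),  # rows = hidden units; inputs are not permuted
    "fc1.bias": ("P_hidden",),
    "fc2.weight": (None, "P_hidden"),  # columns = the same hidden units
    "fc2.bias": (None,),               # outputs are never permuted
})
print(spec.perm_to_axes)
```

In this layout, the p_in/p_out values mentioned later in the thread are simply the input and output slots of each tuple, and a key whose entries are all None (like fc2.bias) is listed but never permuted.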

@ogkalu2 ogkalu2 reopened this Nov 11, 2022
@ogkalu2
Author

ogkalu2 commented Nov 11, 2022

Hi @samuela, thanks for responding. Quite a bit has happened since the last comment.

I found a repo that had converted the code to PyTorch. I also finally got down to writing the permutation spec for Stable Diffusion today. I think I'm on the right track, but I'm not too sure. For instance, I'm not quite sure what the correct p_in and p_out values should be. Running it right now gives a couple of different errors each time.

Sometimes I get something like

File "/content/drive/MyDrive/SD_rebasin/weight_matching.py", line 487, in weight_matching
A += w_a @ w_b.T
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1280x11520 and 5760x1280)

or

"addmm_impl_cpu_" not implemented for 'Half'

Can you take a quick look here and see what you think I might be doing wrong?

https://imgur.com/a/DMnb7P8

@samuela
Owner

samuela commented Nov 11, 2022

This looks like an error in your permutation spec. I would try debugging which weight arrays it's occurring on.
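One way to narrow down which weight arrays cause errors like the shape mismatch above is to check the spec before running weight matching: every axis tied to the same permutation must have the same length, and every spec entry needs one slot per tensor axis. A minimal pure-Python sketch; find_size_conflicts and the example keys are hypothetical, and shapes is assumed to be built from the loaded state dict.

```python
from collections import defaultdict


def find_size_conflicts(axes_to_perm, shapes):
    """Return {perm: {(key, axis): size}} for each permutation whose axes disagree in size.

    axes_to_perm: weight key -> tuple with one permutation name (or None) per axis.
    shapes: weight key -> shape tuple, e.g. {k: tuple(v.shape) for k, v in state_dict.items()}.
    """
    sizes = defaultdict(dict)
    for key, perms in axes_to_perm.items():
        shape = shapes[key]
        if len(perms) != len(shape):
            # e.g. a conv weight (4 axes) written down with a 2-axis dense entry
            raise ValueError(f"{key}: spec lists {len(perms)} axes, tensor has {len(shape)}")
        for axis, perm in enumerate(perms):
            if perm is not None:
                sizes[perm][(key, axis)] = shape[axis]
    return {p: axes for p, axes in sizes.items() if len(set(axes.values())) > 1}


# Hypothetical 3-layer MLP where fc3's input was wired to the wrong permutation.
axes_to_perm = {
    "fc1.weight": ("P_0", None),
    "fc2.weight": ("P_1", "P_0"),
    "fc3.weight": (None, "P_0"),  # should be "P_1"
}
shapes = {"fc1.weight": (512, 768), "fc2.weight": (256, 512), "fc3.weight": (10, 256)}
for perm, axes in find_size_conflicts(axes_to_perm, shapes).items():
    print(perm, axes)
```

Here only P_0 is reported, with fc3.weight axis 1 at 256 while the other P_0 axes are 512, which points straight at the miswired entry.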

@ogkalu2
Author

ogkalu2 commented Nov 11, 2022

I'll try that, thanks.

I also get "RuntimeError: INDICES element is out of DATA bounds, id=256 axis_dim=256"

@ogkalu2
Author

ogkalu2 commented Nov 12, 2022

I've now tried a couple of different tests, removing everything except a few layers that were explicitly labelled in the state dict, to see if that was the issue and whether I could build up from there.

Examples here: https://imgur.com/a/IbAEfdP and https://imgur.com/a/o6SgMk8

But I still get those weird errors.

@ogkalu2
Author

ogkalu2 commented Nov 12, 2022

@samuela
I finally got it to run the final permutation function for the few explicit blocks/layers (I'll build from there once I see everything else works)!

I feel so close, but I've hit another wall on the apply_permutation line. It's giving me "KeyError: betas".

The good news is that I'm pretty sure I know what's happening here. betas is the first key in Stable Diffusion's state dict. It seems to be stuck trying to apply the permutation to betas, but betas wasn't defined in the permutation_spec list of layers or blocks to alter. More importantly, it's not the only key in the state dict I thought best left undisturbed.

Is there any way I can get the apply_permutation function to skip keys that it doesn't have altered values for?

@samuela
Owner

samuela commented Nov 12, 2022

Hmm, I'm not familiar with the stable diffusion architecture... is there a reason not to model permutations on the betas?

You could always add them to your PermutationSpec and just specify None for all the axes. That should make apply_permutation ignore them.

(Btw, since you're using the 3rd-party PyTorch implementation, some things may be different! That's a different codebase.)
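To illustrate why the None-axes trick works: applying a permutation amounts to reindexing each array along every axis that names a permutation, so an entry whose axes are all None passes through unchanged. The NumPy sketch below is a hypothetical stand-in for apply_permutation, not the repository's function (with torch tensors the equivalent reindexing call is torch.index_select). Unlike the version that raised a KeyError, it also copies keys missing from the spec, which is convenient for schedule buffers but means a misspelled key would be skipped silently.

```python
import numpy as np


def apply_perms(axes_to_perm, perms, params):
    """Reindex each array along the axes its spec entry names.

    axes_to_perm: key -> tuple of permutation names / None, one per axis.
    perms: permutation name -> integer index array.
    Keys absent from axes_to_perm, or with only None entries, are copied unchanged.
    """
    out = {}
    for key, w in params.items():
        w = np.asarray(w)
        for axis, p in enumerate(axes_to_perm.get(key, ())):
            if p is not None:
                w = np.take(w, perms[p], axis=axis)
        out[key] = w
    return out


axes_to_perm = {
    "fc1.weight": ("P_0", None),
    "fc1.bias": ("P_0",),
    "betas": (None,),  # listed with None axes: left as is
}
params = {
    "fc1.weight": np.arange(6.0).reshape(3, 2),
    "fc1.bias": np.array([0.0, 1.0, 2.0]),
    "betas": np.array([0.1, 0.2]),
    "alphas_cumprod": np.array([0.9, 0.8]),  # not in the spec at all: also left as is
}
permuted = apply_perms(axes_to_perm, {"P_0": np.array([2, 0, 1])}, params)
print(permuted["fc1.bias"])  # [2. 0. 1.]
```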

@lopho

lopho commented Nov 12, 2022

You can skip betas; it's not a layer of the actual architecture but a stored parameter used for sampling and generating timestep embeddings. The same goes for alphas_cumprod, sqrt_alphas_cumprod, and a dozen others.
Anything that isn't in model.diffusion_model, first_stage_model or cond_stage_model is such a parameter.
Actually, you can probably skip first_stage_model (VAE) and cond_stage_model (CLIP) as well, since those will probably be the same between two models. If not, it would still make more sense to spec them separately, as they are indeed separate models that only interact with each other manually, i.e. they have no connection in the model graph.

The UNet is the interesting part, and it resides in model.diffusion_model.*

EDIT:
to clarify, these are the keys you can skip

betas
alphas_cumprod
alphas_cumprod_prev
sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod
log_one_minus_alphas_cumprod
sqrt_recip_alphas_cumprod
sqrt_recipm1_alphas_cumprod
posterior_variance
posterior_log_variance_clipped
posterior_mean_coef1
posterior_mean_coef2
model_ema.decay
model_ema.num_updates

and these are the ones you can skip if you retain the same VAE and text encoder

cond_stage_model.*
first_stage_model.*
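lopho's list translates directly into a key filter. A sketch with hypothetical names (SCHEDULE_KEYS, SHARED_PREFIXES, keys_to_match); the VAE and text-encoder prefixes are dropped only on request, matching the "if you retain the same VAE and text encoder" condition above.

```python
# Schedule buffers and EMA bookkeeping from the list above: not layers, never permuted.
SCHEDULE_KEYS = frozenset({
    "betas", "alphas_cumprod", "alphas_cumprod_prev", "sqrt_alphas_cumprod",
    "sqrt_one_minus_alphas_cumprod", "log_one_minus_alphas_cumprod",
    "sqrt_recip_alphas_cumprod", "sqrt_recipm1_alphas_cumprod",
    "posterior_variance", "posterior_log_variance_clipped",
    "posterior_mean_coef1", "posterior_mean_coef2",
    "model_ema.decay", "model_ema.num_updates",
})
# Only safe to skip when both models share the same VAE and text encoder.
SHARED_PREFIXES = ("first_stage_model.", "cond_stage_model.")


def keys_to_match(state_dict, same_vae_and_text_encoder=False):
    """Return the state-dict keys that should take part in weight matching."""
    skip = SHARED_PREFIXES if same_vae_and_text_encoder else ()
    # str.startswith(()) is always False, so nothing is skipped by prefix by default.
    return [k for k in state_dict if k not in SCHEDULE_KEYS and not k.startswith(skip)]


# Toy stand-in for a checkpoint's state_dict (illustrative keys, values elided).
sd = dict.fromkeys([
    "betas",
    "model.diffusion_model.time_embed.0.weight",
    "first_stage_model.encoder.conv_in.weight",
    "cond_stage_model.transformer.text_model.final_layer_norm.weight",
    "model_ema.decay",
])
print(keys_to_match(sd))
print(keys_to_match(sd, same_vae_and_text_encoder=True))
```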

@ogkalu2
Author

ogkalu2 commented Nov 12, 2022

Mostly 3 reasons for leaving the betas out:

  • I can't tell the layer type of all the layers in the state_dict, only most of them, so I have to skip some to test it out first. Also, some layers aren't really part of the architecture and will be the same between 2 models.

  • A few correctly identified layers don't work well for some reason. I get an error like

File "/content/drive/MyDrive/SD_rebasin/weight_matching.py", line 293, in weight_matching
w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1))
RuntimeError: shape '[512, -1]' is invalid for input of size 768

  • Strictly speaking, the UNet architecture begins with "model.diffusion_model.time_embed.0.weight", not betas, so I thought I might skip it

By axes, you mean the P_bgx and P_bgy values, right? So None and None, then?

I have a question on that too. There are a lot of layers, and I'm unsure how to label them all correctly. For the layers with 2 P values, do I just keep going sequentially? I reach P_bg50 or so that way.

For the layers with only one, how would that work exactly? It's a bit hard to tell when I need to go from, say, P_bg1 to P_bg2 to P_bg3 and so on. The architecture is divided into 3 parts: the input blocks, the middle block and the output blocks. So I'm wondering, is it P_bg1, P_bg2, P_bg3 for those sets of blocks, or something else?

@ogkalu2
Author

ogkalu2 commented Nov 12, 2022

Thanks. Lots of DreamBooth repos also train the text encoder now, so I won't skip them.

I'm unsure of the type of certain layers.

The layers that have
emb_layers
proj_in
proj_out
transformer_blocks
skip_connection
self_attn
mid.attn

Do you have any idea? Are they norm, conv, dense, or neither?

@lopho

lopho commented Nov 12, 2022

  • emb_layers is SiLU + linear -> dense
  • proj_in is Conv2d
  • proj_out is Conv2d
  • transformer_blocks is BasicTransformerBlock -> CrossAttention, LayerNorm, FeedForward
    • CrossAttention: all linear layers -> dense
    • FeedForward: GEGLU + linear -> dense
      • GEGLU: linear + gelu -> dense
  • skip_connection: is either Identity or Conv2d (Identity has no weights, so going by key, it's always Conv2d)
  • self_attn: CLIPAttention -> dense
  • mid.attn: AttnBlock -> Conv2d

Keep in mind that the different attention layer types might be either all conv or all dense, but they are not just sequentially chained.
Rather, they compute q/k/v projections and batch matrix multiplications on the results. I'm not sure if that has any impact on the results of the permutations.
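Given those layer types, the spec entries follow from PyTorch's weight layouts. Below is a sketch of helper functions in that spirit (the names and the bias flag are our own), applied to one UNet ResBlock. The sub-module indices (in_layers.0/.2, emb_layers.1, out_layers.0/.3) and the input_blocks.4.0 prefix assume the CompVis SD v1 layout, the permutation names are hypothetical, and the attention layers lopho warns about are deliberately left out.

```python
def dense(name, p_in, p_out, bias=True):
    # nn.Linear: weight [out_features, in_features], bias [out_features]
    entries = {f"{name}.weight": (p_out, p_in)}
    if bias:
        entries[f"{name}.bias"] = (p_out,)
    return entries


def conv(name, p_in, p_out, bias=True):
    # nn.Conv2d: weight [out_channels, in_channels, kH, kW]; kernel axes are never permuted
    entries = {f"{name}.weight": (p_out, p_in, None, None)}
    if bias:
        entries[f"{name}.bias"] = (p_out,)
    return entries


def norm(name, p):
    # GroupNorm / LayerNorm affine parameters: weight and bias are both [channels]
    return {f"{name}.weight": (p,), f"{name}.bias": (p,)}


def resblock(prefix, p_in, p_inner, p_out, p_emb):
    """ResBlock laid out as in_layers -> (+ emb_layers) -> out_layers, plus skip_connection."""
    return {
        **norm(f"{prefix}.in_layers.0", p_in),
        **conv(f"{prefix}.in_layers.2", p_in, p_inner),
        # the time embedding is added to the hidden activations, so it shares p_inner
        **dense(f"{prefix}.emb_layers.1", p_emb, p_inner),
        **norm(f"{prefix}.out_layers.0", p_inner),
        **conv(f"{prefix}.out_layers.3", p_inner, p_out),
        # skip(x) is added to the block output, so it must land in the same p_out basis
        **conv(f"{prefix}.skip_connection", p_in, p_out),
    }


spec = resblock("model.diffusion_model.input_blocks.4.0", "P_in", "P_inner", "P_out", "P_emb")
for key, axes in spec.items():
    print(key, axes)
```

The design choice worth noting is that residual additions force sharing: whatever is summed together (skip path and main path, hidden activations and time embedding) has to be expressed in the same permutation.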

@ogkalu2
Author

ogkalu2 commented Nov 12, 2022

Thanks for the response, it's helped a lot. Yes, I think I'm going to skip the attention layers, at least for the first go-around.
As Samuel suggested, labelling the axes as None for betas worked for skipping it, so I just have to do that for all of them now.

Forgot to ask: what are time_embed (I'm assuming dense now), model.out.0 and out.2 (norm and conv, I think) and .op (conv, I think)?

@lopho

lopho commented Nov 13, 2022

unet time_embed: linear, SiLU, linear, so it's dense
unet out: GroupNorm, SiLU, Conv2d
op (I guess you mean the op of the UNet downsampling block) is Conv2d

It's all readily available in the implementation, so I suggest you read it yourself to get a better understanding.
unet: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py
unet attention: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/attention.py
vae: https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/autoencoder.py

@ogkalu2
Author

ogkalu2 commented Nov 15, 2022

Just an update. At first I had trouble building up. Some layers would work, most wouldn't, and, most irritatingly, what would run without error and what wouldn't felt inconsistent. Today I figured out the issue was the axes and the torch sizes (shapes) of the layers. Not all layers can or should be connected by an axis, and the torch sizes help tell which ones can/should be. Anyway, I can get pretty much every layer to permute now, so I'll finish that and finally test this.

@affableroots

@ogkalu2 So the idea is to go PyTorch -> TF -> re-basin -> PyTorch? This sounds huge btw, thanks for doing it.

@affableroots

Also, you've probably seen this, but there's a PyTorch version, though I think you need to come up with your own PermutationSpec: https://github.com/themrzmaster/git-re-basin-pytorch

@ogkalu2
Author

ogkalu2 commented Nov 15, 2022

I'm pretty much just using the PyTorch implementation now; I already knew about the one you linked. I ended up using JAX for flattening and unflattening the parameters, but that's about it.

No problem, it's my pleasure. I'm done with the UNet and working on the text encoder. There's no doubt it'll run now; the only question is whether it merges as hoped. Fingers crossed for that.

@lopho

lopho commented Nov 15, 2022

@ogkalu2 Do you have a repository for this where I could take a look?

@ogkalu2
Author

ogkalu2 commented Nov 15, 2022

@lopho No. I wanted to finish things and see the results of a merged model before I uploaded anything to a repo.

@ogkalu2
Author

ogkalu2 commented Nov 15, 2022

I did upload my first attempt here, though. A few things have changed since then to make it work, mostly the axes:
https://imgur.com/a/DMnb7P8

I also added a bias option for the conv and added the dense emb layers in the easyblock.

@ogkalu2
Author

ogkalu2 commented Nov 16, 2022

Hi @samuela, would this run much faster on a GPU?

@samuela
Owner

samuela commented Nov 16, 2022

Yes, it should run quite a bit faster on a GPU, since that will speed up the matrix multiplications, but the linear assignment problem solve still happens on the CPU, so I don't think the speedup you'd get would be anything too crazy... I've never tried running on CPU only.
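To show which part a GPU helps with: for each permutation, weight matching builds a similarity matrix A from matrix products of the matched weights (the w_a @ w_b.T line in the tracebacks in this thread), then solves a linear assignment problem on A. Below is a NumPy/SciPy sketch for the simplest case, the hidden layer of a two-layer MLP; the function name and shapes are assumptions. With torch tensors, both operands of the product must be on the same device, and A has to come back to the CPU (A.cpu().numpy()) before scipy.optimize.linear_sum_assignment can solve it.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_hidden_units(w1_a, w2_a, w1_b, w2_b):
    """Permutation of model B's hidden units that best aligns them with model A's.

    PyTorch layout: w1 is [hidden, in] (hidden on axis 0), w2 is [out, hidden] (hidden on axis 1).
    """
    n = w1_a.shape[0]
    # Every axis that carries this permutation adds a term: move it to the front, flatten the rest.
    A = w1_a.reshape(n, -1) @ w1_b.reshape(n, -1).T
    A += np.moveaxis(w2_a, 1, 0).reshape(n, -1) @ np.moveaxis(w2_b, 1, 0).reshape(n, -1).T
    # The matrix products above are what a GPU speeds up; this solve is CPU-only.
    _, perm = linear_sum_assignment(A, maximize=True)
    return perm


rng = np.random.default_rng(0)
w1_a, w2_a = rng.normal(size=(8, 5)), rng.normal(size=(3, 8))
hidden_order = rng.permutation(8)
w1_b, w2_b = w1_a[hidden_order], w2_a[:, hidden_order]  # same network, hidden units shuffled
perm = match_hidden_units(w1_a, w2_a, w1_b, w2_b)
print(np.allclose(np.take(w1_b, perm, axis=0), w1_a))  # True
```

Because B here is just a shuffled copy of A, the recovered permutation undoes the shuffle exactly; with two independently trained models the match is only approximate.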

@ogkalu2
Author

ogkalu2 commented Nov 16, 2022

Ah, I see. If I don't specify the device as cpu, I get

File "/notebooks/weight_matching.py", line 798, in weight_matching
A += w_a @ w_b.T
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)

@ogkalu2
Author

ogkalu2 commented Nov 16, 2022

@lopho @affableroots @samuela
I'm done with the spec. I haven't tested a merge yet, and that will take time, so I've uploaded it here for anyone to test as well.
https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion

@ogkalu2
Author

ogkalu2 commented Nov 17, 2022

Hi @samuela, how many iterations does the weight matching typically run? I know the maximum is 100 iterations, but I don't think it usually goes that high?

@samuela
Owner

samuela commented Nov 17, 2022

It totally depends on the model and initialization. I've seen it take as few as 3 and as many as 50. It is guaranteed to terminate though, so don't worry it can't run forever!

@affableroots

I keep OOMing with 32 GB of RAM. Any tips on what I can delete and when, or maybe on running the merge in parts?

@ogkalu2
Author

ogkalu2 commented Nov 17, 2022

Oh wow. I know it can't run forever, but the 1st SD iteration took ~12 hours, so I was curious. Ah well. What exactly do the NewL - OldL values indicate? I see most of them are 0.0
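Assuming the loop works the usual coordinate-descent way (re-solve one permutation at a time with the others held fixed, and stop once a full pass brings no improvement), the logged NewL - OldL is the gain in the alignment objective, the sum of A[i, perm[i]], from re-solving that one permutation. A value of 0.0 means the solver kept the existing permutation. This is also why termination is guaranteed: the objective never decreases and there are only finitely many permutations. A NumPy/SciPy sketch of a single step; the function names and the tolerance are our assumptions.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def alignment_objective(A, perm):
    # Total similarity when unit i of model A is paired with unit perm[i] of model B.
    return float(A[np.arange(len(perm)), perm].sum())


def update_permutation(A, old_perm, tol=1e-12):
    """Re-solve one permutation; return (new_perm, new_L - old_L, made_progress)."""
    _, new_perm = linear_sum_assignment(A, maximize=True)
    old_L = alignment_objective(A, old_perm)
    new_L = alignment_objective(A, new_perm)
    return new_perm, new_L - old_L, new_L > old_L + tol


A = np.array([[3.0, 1.0, 0.0],
              [1.0, 2.0, 0.0],
              [0.0, 0.0, 1.0]])
print(update_permutation(A, np.array([0, 1, 2]))[1:])  # (0.0, False): already optimal
print(update_permutation(A, np.array([1, 0, 2]))[1:])  # (3.0, True)
```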

@ogkalu2
Author

ogkalu2 commented Nov 17, 2022

Really? Huh. I'm just running on vast right now. You can't really run it in parts at the moment. As for what to delete, it's possible to skip some layers, but I honestly don't know exactly what I can skip yet. The VAE layers would be the first thing I'd remove, but I don't know beyond that. I'll look into it.

The OOM errors seem odd, though. Do you actually get those errors on your console/terminal, or does your system freeze up or something?

@affableroots

affableroots commented Nov 17, 2022

Watching htop, I can see it OOM, and I also get a "Killed" at the following step. If it matters, I'm just testing a merge of 2 simple 4 GB .ckpts built off of sd-v1-4:

...
0/P_model.diffusion_model.output_blocks.6.0_inner2: 0.0
0/P_bg308: 0.0
0/P_model.diffusion_model.middle_block.2_inner4: 0.0
0/P_bg353: 0.0
0/P_bg166: 0.0
0/P_bg180: 0.0
0/P_bg65: 0.0
0/P_bg78: 0.0
0/P_first_stage_model.encoder.mid.block_1_inner: 0.0
0/P_bg214: 0.0
0/P_first_stage_model.decoder.up.3.block.1_inner: 0.0
0/P_bg313: 0.0
0/P_model.diffusion_model.output_blocks.0.0_inner: 0.0
0/P_bg98: 0.0
0/P_bg359: 0.0
0/P_bg141: 0.0
0/P_bg264: 0.0
0/P_bg163: 0.0
Killed

EDIT: skipping the VAE makes sense; that's a good idea.

@ogkalu2
Author

ogkalu2 commented Nov 17, 2022

Oh, I see. The test I have running uses 2 DreamBooth models pruned to 2 GB. The bigger the models, the higher the RAM usage; I didn't realize 4 GB models were currently too much for 32 GB RAM systems. The problem is the linear sum assignment. It can only run on the CPU.

@ogkalu2
Author

ogkalu2 commented Nov 17, 2022

@samuela
There's someone here who's running through this in minutes, lol:
ogkalu2/Merge-Stable-Diffusion-models-without-distortion#1 (comment)

Anyway, I have a new problem now. The perm spec runs fine and the parameters get updated fine, but the line I wrote earlier to save the model won't work. After defining the state dict(s) as state_a = model_a["state_dict"], I tried to save the model with
torch.save({
"state_dict": state_b(updated_params)
}, output_file)

but got hit with

"state_dict": state_b(updated_params)
TypeError: 'dict' object is not callable
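The TypeError comes from state_b(updated_params): state_b is a plain dict, so it can't be called like a function. A minimal fix, assuming updated_params is a dict keyed like state_b (for example, the output of apply_permutation), is to merge the updated entries into a copy of the state dict and save that. build_checkpoint is a hypothetical helper; the torch.save call is left as a comment so the snippet runs without PyTorch.

```python
def build_checkpoint(state_b, updated_params):
    """Return a checkpoint dict whose state_dict is state_b with updated_params merged in."""
    merged = dict(state_b)         # shallow copy: state_b itself is left untouched
    merged.update(updated_params)  # permuted tensors overwrite entries with the same key
    return {"state_dict": merged}  # keys absent from updated_params (e.g. betas) are kept


# Usage with real checkpoints:
# torch.save(build_checkpoint(state_b, updated_params), output_file)

ckpt = build_checkpoint({"fc.weight": "old", "betas": "schedule"}, {"fc.weight": "permuted"})
print(ckpt)
```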

@ogkalu2
Author

ogkalu2 commented Nov 18, 2022

Hi @samuela
Something seems to be up with the get_permuted_params function. Applying the permutation with it just seems to bias the result toward whichever model's parameters are selected.

So a merge with apply_permutation(permutation_spec, final_permutation, model_a state dict) just produces a model that is basically model A, and a merge with apply_permutation(permutation_spec, final_permutation, model_b state dict) just produces a model that is basically model B.

Any idea what the issue might be?

@samuela
Owner

samuela commented Nov 23, 2022

Hi @ogkalu2, how are you measuring the difference between the permuted model and the original? Have you inspected the final_permutations to see if they are close to identity? Depending on the kind of fine-tuning you're doing, it could reasonably be expected that the optimal permutation is already close to identity. It's a convenient property of fine-tuning that you generally don't leave the pre-training basin, esp. with large models and small learning rates; see e.g. https://arxiv.org/abs/2109.01903, https://twitter.com/moyix/status/1581390268368302080, and so forth.
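Following that suggestion, one quick check is the fraction of fixed points in each permutation: 1.0 means the identity, so if nearly every permutation scores close to 1.0, the permuted model B is essentially the original model B. A sketch assuming final_permutation is a dict from permutation name to an integer index array (torch tensors would first need .cpu().numpy()); the helper names are hypothetical.

```python
import numpy as np


def identity_fraction(perm):
    """Fraction of positions i with perm[i] == i; 1.0 means the identity permutation."""
    perm = np.asarray(perm)
    return float(np.mean(perm == np.arange(len(perm))))


def summarize(final_permutation):
    # Least identity-like first, so the permutations that actually moved units stand out.
    return sorted((identity_fraction(p), name) for name, p in final_permutation.items())


print(summarize({"P_bg0": [0, 1, 2, 3], "P_bg1": [1, 0, 2, 3]}))
# [(0.5, 'P_bg1'), (1.0, 'P_bg0')]
```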

@liruiw

liruiw commented Feb 1, 2023

Hello, I wonder if there is an update on the automatic tracer in pytorch for the permutation spec? Thanks!

5 participants