Conversion for stable-diffusion-webui #1

Closed
catboxanon opened this issue Mar 28, 2023 · 1 comment

Comments

@catboxanon

Hi, great research! Impressed by the results.

For your own interest, and in case anybody else comes across this: you can use this conversion script to get the LoRA models working with AUTOMATIC1111/stable-diffusion-webui, the interface that the majority of the SD community uses. Credit to harrywang for the original script.

import re
import os
import argparse
import torch
from safetensors.torch import save_file

def main(args):
    # The conversion only renames keys, so loading on the CPU is enough.
    checkpoint = torch.load(args.file, map_location=torch.device('cpu'))

    new_dict = dict()
    for key in checkpoint:
        # Raw strings avoid invalid-escape warnings; literal dots are escaped.
        new_key = re.sub(r'\.processor\.', '_', key)
        new_key = re.sub(r'mid_block\.', 'mid_block_', new_key)
        new_key = re.sub(r'_lora\.up\.', '.lora_up.', new_key)
        new_key = re.sub(r'_lora\.down\.', '.lora_down.', new_key)
        # Turn numeric path segments such as ".0." into "_0_".
        new_key = re.sub(r'\.(\d+)\.', r'_\1_', new_key)
        new_key = re.sub(r'to_out', 'to_out_0', new_key)
        new_key = 'lora_unet_' + new_key

        new_dict[new_key] = checkpoint[key]

    # Write the result next to the input file, e.g. foo.bin -> foo_converted.safetensors
    file_name = os.path.splitext(args.file)[0]
    new_lora_name = file_name + '_converted.safetensors'
    print("Saving " + new_lora_name)
    save_file(new_dict, new_lora_name)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--file",
        type=str,
        required=True,
        help="Path to the LoRA checkpoint to convert",
    )

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    main(args)
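
To make the renaming easier to follow, here is a small sketch that applies the same substitutions to a single key without loading any checkpoint. The `convert_key` helper and the sample key are my own illustrations (the sample is a hypothetical diffusers-style name, not taken from an actual file), so treat it as a way to check the regex chain rather than as part of the script.

```python
import re


def convert_key(key: str) -> str:
    """Apply the script's substitutions to one state-dict key (illustrative helper)."""
    new_key = re.sub(r'\.processor\.', '_', key)
    new_key = re.sub(r'mid_block\.', 'mid_block_', new_key)
    new_key = re.sub(r'_lora\.up\.', '.lora_up.', new_key)
    new_key = re.sub(r'_lora\.down\.', '.lora_down.', new_key)
    # Numeric path segments such as ".0." become "_0_".
    new_key = re.sub(r'\.(\d+)\.', r'_\1_', new_key)
    new_key = re.sub(r'to_out', 'to_out_0', new_key)
    return 'lora_unet_' + new_key


# Hypothetical key, chosen only to exercise every substitution except mid_block.
sample = ('down_blocks.0.attentions.1.transformer_blocks.0'
          '.attn1.processor.to_out_lora.up.weight')
print(convert_key(sample))
# lora_unet_down_blocks_0_attentions_1_transformer_blocks_0_attn1_to_out_0.lora_up.weight
```

The module path is flattened with underscores, while the `.lora_up.` / `.lora_down.` part and the trailing `.weight` keep their dots.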
@tgxs002
Owner

tgxs002 commented Mar 28, 2023

Thank you for pointing it out. I will update the README later.
