Support: OFT merge to base model #1580

Merged (4 commits) on Sep 13, 2024
networks/sdxl_merge_lora.py: 192 changes (144 additions, 48 deletions)
@@ -8,10 +8,12 @@
from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util
import lora
import oft
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
import concurrent.futures

def load_state_dict(file_name, dtype):
    if os.path.splitext(file_name)[1] == ".safetensors":
@@ -39,82 +41,176 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
    else:
        torch.save(model, file_name)

def detect_method_from_training_model(models, dtype):
    for model in models:
        lora_sd, _ = load_state_dict(model, dtype)
        for key in tqdm(lora_sd.keys()):
            if 'lora_up' in key or 'lora_down' in key:
                return 'LoRA'
            elif "oft_blocks" in key:
                return 'OFT'
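
# Note: the function returns None implicitly when a file contains neither LoRA
# ("lora_up"/"lora_down") nor OFT ("oft_blocks") keys. Key-name sketch
# (illustrative, not copied from a real checkpoint):
#   "lora_unet_..._attn1_to_q.lora_down.weight" -> 'LoRA'
#   "oft_unet_..._attn1_to_q.oft_blocks"        -> 'OFT'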

def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
    text_encoder1.to(merge_dtype)
    text_encoder2.to(merge_dtype)
    unet.to(merge_dtype)

    # detect the method: OFT or LoRA_module
    method = detect_method_from_training_model(models, merge_dtype)
    logger.info(f"method:{method}")

    # create module map
    name_to_module = {}
    for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
        if method == 'LoRA':
            if i <= 1:
                if i == 0:
                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
                else:
                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
                target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
            else:
                prefix = lora.LoRANetwork.LORA_PREFIX_UNET
                target_replace_modules = (
                    lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
                )
        elif method == 'OFT':
            prefix = oft.OFTNetwork.OFT_PREFIX_UNET
            target_replace_modules = (
                oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
            )

        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    # LoRA and OFT share the same mapping rule for Linear/Conv2d children
                    if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        name_to_module[lora_name] = child_module
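    # the resulting keys look like "lora_unet_<module_path>" with every "." in the
    # path flattened to "_" (same convention for the OFT prefix), matching the layer
    # names under which the trained network saved its weights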


    for model, ratio in zip(models, ratios):
        logger.info(f"loading: {model}")
        lora_sd, _ = load_state_dict(model, merge_dtype)

        logger.info("merging...")

        if method == 'LoRA':
            for key in tqdm(lora_sd.keys()):
                if "lora_down" in key:
                    up_key = key.replace("lora_down", "lora_up")
                    alpha_key = key[: key.index("lora_down")] + "alpha"

                    # find original module for this lora
                    module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
                    if module_name not in name_to_module:
                        logger.info(f"no module found for LoRA weight: {key}")
                        continue
                    module = name_to_module[module_name]
                    # logger.info(f"apply {key} to {module}")

                    down_weight = lora_sd[key]
                    up_weight = lora_sd[up_key]

                    dim = down_weight.size()[0]
                    alpha = lora_sd.get(alpha_key, dim)
                    scale = alpha / dim
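                    # alpha defaults to dim when the checkpoint stores no alpha key,
                    # so scale is 1.0 by default; ratio is the per-file merge weight
                    # passed on the command line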

                    # W <- W + U * D
                    weight = module.weight
                    # logger.info(module_name, down_weight.size(), up_weight.size())
                    if len(weight.size()) == 2:
                        # linear
                        weight = weight + ratio * (up_weight @ down_weight) * scale
                    elif down_weight.size()[2:4] == (1, 1):
                        # conv2d 1x1
                        weight = (
                            weight
                            + ratio
                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                            * scale
                        )
                    else:
                        # conv2d 3x3
                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                        # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                        weight = weight + ratio * conved * scale

                    module.weight = torch.nn.Parameter(weight)


        elif method == 'OFT':

            multiplier = 1.0
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            for key in tqdm(lora_sd.keys()):
                if "oft_blocks" in key:
                    oft_blocks = lora_sd[key]
                    dim = oft_blocks.shape[0]
                    break
            for key in tqdm(lora_sd.keys()):
                if "alpha" in key:
                    alpha = lora_sd[key].item()
                    break

            def merge_to(key):
                if "alpha" in key:
                    return

                # find original module for this OFT block
                module_name = ".".join(key.split(".")[:-1])
                if module_name not in name_to_module:
                    logger.info(f"no module found for OFT weight: {key}")
                    return
                module = name_to_module[module_name]
                # logger.info(f"apply {key} to {module}")

                oft_blocks = lora_sd[key]

                if isinstance(module, torch.nn.Linear):
                    out_dim = module.out_features
                elif isinstance(module, torch.nn.Conv2d):
                    out_dim = module.out_channels

                num_blocks = dim
                block_size = out_dim // dim
                constraint = (0 if alpha is None else alpha) * out_dim
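
                # Cayley parameterization: block_Q is made skew-symmetric (Q^T = -Q) and
                # its norm clamped to the constraint; R = (I + Q)(I - Q)^(-1) is then
                # orthogonal by construction, so merging applies a pure rotation to W.
                # multiplier blends R with the identity; at 1.0 the rotation is applied
                # at full strength (the per-model ratio is not used in this branch).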

                block_Q = oft_blocks - oft_blocks.transpose(1, 2)
                norm_Q = torch.norm(block_Q.flatten())
                new_norm_Q = torch.clamp(norm_Q, max=constraint)
                block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
                I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
                block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
                block_R_weighted = multiplier * block_R + (1 - multiplier) * I
                R = torch.block_diag(*block_R_weighted)

                # get org weight
                org_sd = module.state_dict()
                org_weight = org_sd["weight"].to(device)

                R = R.to(org_weight.device, dtype=org_weight.dtype)

                if org_weight.dim() == 4:
                    weight = torch.einsum("oihw, op -> pihw", org_weight, R)
                else:
                    weight = torch.einsum("oi, op -> pi", org_weight, R)

                weight = weight.contiguous()  # Make Tensor contiguous; required due to ThreadPoolExecutor

                module.weight = torch.nn.Parameter(weight)

            with concurrent.futures.ThreadPoolExecutor() as executor:
                list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
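
# Note on the thread pool above: dispatching merge_to over a ThreadPoolExecutor can
# give real parallelism here, since PyTorch releases the GIL inside heavy tensor
# kernels (inverse, einsum, matmul).
#
# For intuition, the Cayley step can be checked in isolation: any skew-symmetric Q
# maps to an orthogonal R, so the merge applies a norm-preserving rotation to the
# weights. Standalone sketch (block count/size are illustrative, not from a real model):
#
#   blocks = torch.randn(4, 8, 8) * 0.01
#   Q = blocks - blocks.transpose(1, 2)                    # skew-symmetric: Q^T = -Q
#   I = torch.eye(8).unsqueeze(0).repeat(4, 1, 1)
#   R = torch.block_diag(*torch.matmul(I + Q, (I - Q).inverse()))
#   assert (R.T @ R - torch.eye(32)).abs().max() < 1e-4    # R^T R ~ I, i.e. a rotation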


def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
    base_alphas = {}  # alpha for merged model