Commit

Merge branch 'dev' into sd3
kohya-ss committed Sep 13, 2024
2 parents cefe526 + 9d28607 commit f3ce80e
Showing 2 changed files with 161 additions and 50 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -704,6 +704,13 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!

- Added a `--f` option to the prompt options of `gen_imgs.py` to specify the file name when saving. The script also now supports LoRA weights with Diffusers-based keys.


### Sep 13, 2024 / 2024-09-13:

- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release.

- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). This will be included in the next release.
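  - A minimal usage sketch (flag names are taken from the script's argument parser shown below; the file names are placeholders): `python networks/sdxl_merge_lora.py --sd_model sdxl_base.safetensors --models my_oft.safetensors --ratios 1.0 --save_to merged.safetensors`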

### Jun 23, 2024 / 2024-06-23:

- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
204 changes: 154 additions & 50 deletions networks/sdxl_merge_lora.py
@@ -8,10 +8,15 @@
from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util
import lora
import oft
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)
import concurrent.futures


def load_state_dict(file_name, dtype):
    if os.path.splitext(file_name)[1] == ".safetensors":
@@ -40,24 +45,45 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
        torch.save(model, file_name)


def detect_method_from_training_model(models, dtype):
    for model in models:
        lora_sd, _ = load_state_dict(model, dtype)
        for key in tqdm(lora_sd.keys()):
            if "lora_up" in key or "lora_down" in key:
                return "LoRA"
            elif "oft_blocks" in key:
                return "OFT"


def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
    text_encoder1.to(merge_dtype)
    text_encoder2.to(merge_dtype)
    unet.to(merge_dtype)

    # detect the method: OFT or LoRA
    method = detect_method_from_training_model(models, merge_dtype)
    logger.info(f"method: {method}")

    # create module map
    name_to_module = {}
    for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
        if method == "LoRA":
            if i <= 1:
                if i == 0:
                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
                else:
                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
                target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
            else:
                prefix = lora.LoRANetwork.LORA_PREFIX_UNET
                target_replace_modules = (
                    lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
                )
        elif method == "OFT":
            prefix = oft.OFTNetwork.OFT_PREFIX_UNET
            # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
            target_replace_modules = (
                oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
            )

        for name, module in root_module.named_modules():
@@ -73,48 +99,119 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
        lora_sd, _ = load_state_dict(model, merge_dtype)

        logger.info(f"merging...")
        if method == "LoRA":
            for key in tqdm(lora_sd.keys()):
                if "lora_down" in key:
                    up_key = key.replace("lora_down", "lora_up")
                    alpha_key = key[: key.index("lora_down")] + "alpha"

                    # find original module for this lora
                    module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
                    if module_name not in name_to_module:
                        logger.info(f"no module found for LoRA weight: {key}")
                        continue
                    module = name_to_module[module_name]
                    # logger.info(f"apply {key} to {module}")

                    down_weight = lora_sd[key]
                    up_weight = lora_sd[up_key]

                    dim = down_weight.size()[0]
                    alpha = lora_sd.get(alpha_key, dim)
                    scale = alpha / dim

                    # W <- W + U * D
                    weight = module.weight
                    # logger.info(module_name, down_weight.size(), up_weight.size())
                    if len(weight.size()) == 2:
                        # linear
                        weight = weight + ratio * (up_weight @ down_weight) * scale
                    elif down_weight.size()[2:4] == (1, 1):
                        # conv2d 1x1
                        weight = (
                            weight
                            + ratio
                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                            * scale
                        )
                    else:
                        # conv2d 3x3
                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                        # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                        weight = weight + ratio * conved * scale

                    module.weight = torch.nn.Parameter(weight)
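            # In the linear case above the update is W <- W + ratio * (alpha / dim) * (up @ down),
            # so the merged checkpoint reproduces the LoRA forward pass without
            # needing the adapter at inference time.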

        elif method == "OFT":

            multiplier = 1.0
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            for key in tqdm(lora_sd.keys()):
                if "oft_blocks" in key:
                    oft_blocks = lora_sd[key]
                    dim = oft_blocks.shape[0]
                    break
            for key in tqdm(lora_sd.keys()):
                if "alpha" in key:
                    oft_blocks = lora_sd[key]
                    alpha = oft_blocks.item()
                    break

            def merge_to(key):
                if "alpha" in key:
                    return

                # find original module for this OFT
                module_name = ".".join(key.split(".")[:-1])
                if module_name not in name_to_module:
                    logger.info(f"no module found for OFT weight: {key}")
                    return
                module = name_to_module[module_name]

# logger.info(f"apply {key} to {module}")

down_weight = lora_sd[key]
up_weight = lora_sd[up_key]

dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim

# W <- W + U * D
weight = module.weight
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
oft_blocks = lora_sd[key]

if isinstance(module, torch.nn.Linear):
out_dim = module.out_features
elif isinstance(module, torch.nn.Conv2d):
out_dim = module.out_channels

num_blocks = dim
block_size = out_dim // dim
constraint = (0 if alpha is None else alpha) * out_dim

block_Q = oft_blocks - oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = multiplier * block_R + (1 - multiplier) * I
R = torch.block_diag(*block_R_weighted)
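                # block_Q is skew-symmetric by construction, so the Cayley transform
                # R = (I + Q)(I - Q)^-1 is orthogonal; torch.block_diag assembles the
                # per-block rotations into one block-diagonal matrix over out_dim.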

                # get org weight
                org_sd = module.state_dict()
                org_weight = org_sd["weight"].to(device)

                R = R.to(org_weight.device, dtype=org_weight.dtype)

                if org_weight.dim() == 4:
                    weight = torch.einsum("oihw, op -> pihw", org_weight, R)
                else:
                    weight = torch.einsum("oi, op -> pi", org_weight, R)

                weight = weight.contiguous()  # make tensor contiguous; required due to ThreadPoolExecutor

                module.weight = torch.nn.Parameter(weight)

            # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
            max_workers = 1 if device.type != "cpu" else None  # avoid OOM on GPU
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
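As a side note, the Cayley step in `merge_to` above has a property worth making explicit: for a skew-symmetric Q, (I + Q)(I - Q)^-1 is orthogonal, so R only rotates the weight's output dimension. The following standalone sketch (not part of the commit; the sizes are arbitrary and the `constraint`/`multiplier` handling is omitted) checks this:

import torch

num_blocks, block_size = 4, 8
blocks = torch.randn(num_blocks, block_size, block_size, dtype=torch.float64)
Q = blocks - blocks.transpose(1, 2)  # skew-symmetric, as in merge_to
I = torch.eye(block_size, dtype=torch.float64).unsqueeze(0).repeat(num_blocks, 1, 1)
block_R = torch.matmul(I + Q, (I - Q).inverse())  # batched Cayley transform
R = torch.block_diag(*block_R)  # one rotation over the whole output dim

# R is orthogonal: R^T @ R == I up to numerical error
assert torch.allclose(R.T @ R, torch.eye(num_blocks * block_size, dtype=torch.float64), atol=1e-9)

# einsum("oi, op -> pi", W, R) applies R^T on the output dimension;
# an orthogonal R preserves the weight's norm
W = torch.randn(num_blocks * block_size, 16, dtype=torch.float64)
W2 = torch.einsum("oi, op -> pi", W, R)
assert torch.allclose(W.norm(), W2.norm())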


def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
    base_alphas = {}  # alpha for merged model
@@ -164,7 +261,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
        for key in tqdm(lora_sd.keys()):
            if "alpha" in key:
                continue

            if "lora_up" in key and concat:
                concat_dim = 1
            elif "lora_down" in key and concat:
@@ -178,8 +275,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
            alpha = alphas[lora_module_name]

            scale = math.sqrt(alpha / base_alpha) * ratio
            scale = abs(scale) if "lora_up" in key else scale  # handle negative weights

            if key in merged_sd:
                assert (
                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
@@ -201,7 +298,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
                dim = merged_sd[key_down].shape[0]
                perm = torch.randperm(dim)
                merged_sd[key_down] = merged_sd[key_down][perm]
                merged_sd[key_up] = merged_sd[key_up][:, perm]
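                # note: applying the same permutation to lora_down's rows and
                # lora_up's columns cancels in the product, (U[:, perm]) @ (D[perm, :]) == U @ D,
                # so shuffling changes the stored matrices but not the merged delta.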

logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
Expand Down Expand Up @@ -229,7 +326,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):


def merge(args):
    assert len(args.models) == len(
        args.ratios
    ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"

    def str_to_dtype(p):
        if p == "float":
@@ -316,10 +415,16 @@ def setup_parser() -> argparse.ArgumentParser:
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
"--models",
type=str,
nargs="*",
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument(
Expand All @@ -337,8 +442,7 @@ def setup_parser() -> argparse.ArgumentParser:
    parser.add_argument(
        "--shuffle",
        action="store_true",
        help="shuffle lora weight / LoRAの重みをシャッフルする",
    )

    return parser
