Skip to content

Commit

Permalink
Merge pull request #872 from kohya-ss/dev
Browse files Browse the repository at this point in the history
fix make_captions_by_git, improve image generation scripts
  • Loading branch information
kohya-ss authored Oct 10, 2023
2 parents 33ee0ac + 681034d commit 2a23713
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 40 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,21 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum

## Change History

### Oct 11, 2023 / 2023/10/11
- Fixed `make_captions_by_git.py` to work with the latest version of transformers.
- Improve `gen_img_diffusers.py` and `sdxl_gen_img.py`. Both scripts now support the following options:
- `--network_merge_n_models` option can be used to merge some of the models. The remaining models aren't merged, so the multiplier can be changed, and the regional LoRA also works.
- `--network_regional_mask_max_color_codes` is added. Now you can use up to 7 regions.
    - When this option is specified, the mask of the regional LoRA is based on color codes instead of channels. The value is the maximum number of color codes (up to 7).
- You can specify the mask for each LoRA by colors: 0x0000ff, 0x00ff00, 0x00ffff, 0xff0000, 0xff00ff, 0xffff00, 0xffffff.

- `make_captions_by_git.py` が最新の transformers で動作するように修正しました。
- `gen_img_diffusers.py` と `sdxl_gen_img.py` を更新し、以下のオプションを追加しました。
- `--network_merge_n_models` オプションで一部のモデルのみマージできます。残りのモデルはマージされないため、重みを変更したり、領域別LoRAを使用したりできます。
- `--network_regional_mask_max_color_codes` を追加しました。最大7つの領域を使用できます。
- このオプションを指定すると、領域別LoRAのマスクはチャンネルベースではなくカラーコードベースになります。値はカラーコードの最大数(最大7)です。
- 各LoRAに対してマスクをカラーで指定できます:0x0000ff、0x00ff00、0x00ffff、0xff0000、0xff00ff、0xffff00、0xffffff。

### Oct 9, 2023 / 2023/10/9

- `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is not required anymore. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!
Expand Down
6 changes: 5 additions & 1 deletion finetune/make_captions_by_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def collate_fn_remove_corrupted(batch):


def main(args):
r"""
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
Expand All @@ -65,6 +68,7 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""

print(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
Expand All @@ -81,7 +85,7 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]

curr_batch_size[0] = len(path_imgs)
# curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
Expand Down
70 changes: 50 additions & 20 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,13 @@
import diffusers
import numpy as np
import torch

try:
import intel_extension_for_pytorch as ipex

if torch.xpu.is_available():
from library.ipex import ipex_init

ipex_init()
except Exception:
pass
Expand Down Expand Up @@ -954,7 +957,7 @@ def __call__(
text_emb_last = torch.stack(text_emb_last)
else:
text_emb_last = text_embeddings

for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
Expand Down Expand Up @@ -2363,12 +2366,19 @@ def __getattr__(self, item):
network_default_muls = []
network_pre_calc = args.network_pre_calc

# merge関連の引数を統合する
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = 0

for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)

network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)

net_kwargs = {}
if args.network_args and i < len(args.network_args):
Expand All @@ -2379,31 +2389,32 @@ def __getattr__(self, item):
key, value = net_arg.split("=")
net_kwargs[key] = value

if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")

if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)

with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open

network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.")
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")

network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
if network is None:
return

mergeable = network.is_mergeable()
if args.network_merge and not mergeable:
if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")

if not args.network_merge or not mergeable:
if not mergeable or i >= network_merge:
# not merging
network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
Expand All @@ -2417,6 +2428,7 @@ def __getattr__(self, item):
network.backup_weights()

networks.append(network)
network_default_muls.append(network_mul)
else:
network.merge_to(text_encoder, unet, weights_sd, dtype, device)

Expand Down Expand Up @@ -2712,9 +2724,18 @@ def resize_images(imgs, size):

size = None
for i, network in enumerate(networks):
if i < 3:
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]

if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する
ch0 = (i + 1) & 1
ch1 = ((i + 1) >> 1) & 1
ch2 = ((i + 1) >> 2) & 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
np_mask = np_mask.astype(np.uint8) * 255
else:
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
Expand Down Expand Up @@ -3367,10 +3388,19 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
)
parser.add_argument(
"--network_regional_mask_max_color_codes",
type=int,
default=None,
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)",
)
parser.add_argument(
"--textual_inversion_embeddings",
type=str,
Expand Down
69 changes: 50 additions & 19 deletions sdxl_gen_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import diffusers
import numpy as np
import torch

try:
import intel_extension_for_pytorch as ipex

if torch.xpu.is_available():
from library.ipex import ipex_init

ipex_init()
except Exception:
pass
Expand Down Expand Up @@ -1534,12 +1537,20 @@ def __getattr__(self, item):
network_default_muls = []
network_pre_calc = args.network_pre_calc

# merge関連の引数を統合する
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = 0
print(f"network_merge: {network_merge}")

for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)

network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)

net_kwargs = {}
if args.network_args and i < len(args.network_args):
Expand All @@ -1550,31 +1561,32 @@ def __getattr__(self, item):
key, value = net_arg.split("=")
net_kwargs[key] = value

if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")

if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)

with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open

network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.")
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")

network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
if network is None:
return

mergeable = network.is_mergeable()
if args.network_merge and not mergeable:
if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")

if not args.network_merge or not mergeable:
if not mergeable or i >= network_merge:
# not merging
network.apply_to([text_encoder1, text_encoder2], unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
Expand All @@ -1588,6 +1600,7 @@ def __getattr__(self, item):
network.backup_weights()

networks.append(network)
network_default_muls.append(network_mul)
else:
network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)

Expand Down Expand Up @@ -1864,9 +1877,18 @@ def resize_images(imgs, size):

size = None
for i, network in enumerate(networks):
if i < 3:
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]

if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する
ch0 = (i + 1) & 1
ch1 = ((i + 1) >> 1) & 1
ch2 = ((i + 1) >> 2) & 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
np_mask = np_mask.astype(np.uint8) * 255
else:
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
Expand Down Expand Up @@ -2615,10 +2637,19 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
)
parser.add_argument(
"--network_regional_mask_max_color_codes",
type=int,
default=None,
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)",
)
parser.add_argument(
"--textual_inversion_embeddings",
type=str,
Expand Down

0 comments on commit 2a23713

Please sign in to comment.