Skip to content

Commit

Permalink
remove dependency for omegaconf #ref 1284
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed May 19, 2024
1 parent de0e0b9 commit e4d9e3c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 8 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

- Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
- Some settings, such as API keys and directory specifications, are not output due to security issues.

- The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds!

- An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
- It seems that the model file loading is faster in the WSL environment etc.
Expand Down Expand Up @@ -215,6 +217,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
- 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。
- API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。

- SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。

- SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。
- WSL 環境等でモデルファイルの読み込みが高速化されるようです。
- `sdxl_train.py``sdxl_train_network.py``sdxl_train_textual_inversion.py``sdxl_train_control_net_lllite.py` で使用可能です。
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.20.1
omegaconf==2.3.0
# for Image utils
imagesize==1.4.1
# for BLIP captioning
Expand Down
38 changes: 31 additions & 7 deletions train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
import random
import time
from multiprocessing import Value
from omegaconf import OmegaConf

# from omegaconf import OmegaConf
import toml

from tqdm import tqdm

import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from torch.nn.parallel import DistributedDataParallel as DDP
Expand Down Expand Up @@ -197,7 +199,23 @@ def train(args):
"resnet_time_scale_shift": "default",
"projection_class_embeddings_input_dim": None,
}
unet.config = OmegaConf.create(unet.config)
# unet.config = OmegaConf.create(unet.config)

# make unet.config iterable and accessible by attribute
class CustomConfig:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)

def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

def __contains__(self, name):
return name in self.__dict__

unet.config = CustomConfig(**unet.config)

controlnet = ControlNetModel.from_unet(unet)

Expand Down Expand Up @@ -230,7 +248,7 @@ def train(args):
)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

accelerator.wait_for_everyone()

if args.gradient_checkpointing:
Expand All @@ -239,7 +257,7 @@ def train(args):
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")

trainable_params = controlnet.parameters()
trainable_params = list(controlnet.parameters())

_, _, optimizer = train_util.get_optimizer(args, trainable_params)

Expand Down Expand Up @@ -348,7 +366,9 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)

loss_recorder = train_util.LossRecorder()
Expand Down Expand Up @@ -424,7 +444,9 @@ def remove_model(old_ckpt_name):
)

# Sample a random timestep for each image
timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
timesteps, huber_c = train_util.get_timesteps_and_huber_c(
args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
Expand Down Expand Up @@ -456,7 +478,9 @@ def remove_model(old_ckpt_name):
else:
target = noise

loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
Expand Down

0 comments on commit e4d9e3c

Please sign in to comment.