From 4477116a64bb6c363d0fd9fbf3e21bb813548dfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Sat, 20 Apr 2024 21:26:09 +0800 Subject: [PATCH 1/2] fix train controlnet --- library/train_util.py | 4 ++-- requirements.txt | 1 + train_controlnet.py | 8 ++++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..ecf3345fb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1982,8 +1982,8 @@ def make_buckets(self): self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): - return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8): + return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor) def __len__(self): return self.dreambooth_dataset_delegate.__len__() diff --git a/requirements.txt b/requirements.txt index e99775b8a..9495dab2a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 huggingface-hub==0.20.1 +omegaconf==2.3.0 # for Image utils imagesize==1.4.1 # for BLIP captioning diff --git a/train_controlnet.py b/train_controlnet.py index f4c94e8d9..763041aa6 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -5,7 +5,7 @@ import random import time from multiprocessing import Value -from types import SimpleNamespace +from omegaconf import OmegaConf import toml from tqdm import tqdm @@ -148,8 +148,10 @@ def train(args): "in_channels": 4, "layers_per_block": 2, "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", "norm_eps": 1e-05, "norm_num_groups": 32, + "num_attention_heads": [5, 10, 20, 20], "num_class_embeds": None, "only_cross_attention": False, "out_channels": 4, @@ -179,8 +181,10 @@ def train(args): "in_channels": 4, "layers_per_block": 2, "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", "norm_eps": 1e-05, "norm_num_groups": 32, + "num_attention_heads": 8, "out_channels": 4, "sample_size": 64, "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], @@ -193,7 +197,7 @@ def train(args): "resnet_time_scale_shift": "default", "projection_class_embeddings_input_dim": None, } - unet.config = SimpleNamespace(**unet.config) + unet.config = OmegaConf.create(unet.config) controlnet = ControlNetModel.from_unet(unet) From 5cb145d13bd9fae307a8766f4088b95f01492580 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Sat, 20 Apr 2024 21:56:24 +0800 Subject: [PATCH 2/2] Update train_util.py --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ecf3345fb..15c23f3cc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1982,8 +1982,8 @@ def make_buckets(self): self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8): - return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor) + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) def __len__(self): return self.dreambooth_dataset_delegate.__len__()