Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add metadata caching for DreamBooth dataset #1206

Merged
merged 3 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/config_README-en.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ Options related to the configuration of DreamBooth subsets.
| `image_dir` | `'C:\hoge'` | - | - | o (required) |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `"sks girl"` | - | - | o |
| `cache_info` | `false` | o | o | o |
| `is_reg` | `false` | - | - | o |

Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`.
Expand All @@ -187,6 +188,9 @@ Firstly, note that for `image_dir`, the path to the image files must be specifie
* `class_tokens`
* Sets the class tokens.
* Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur.
* `cache_info`
* Specifies whether to cache the image size and caption. If not specified, it is set to `false`. The cache is saved in `metadata_cache.json` in `image_dir`.
* Caching speeds up the loading of the dataset after the first time. It is effective when dealing with thousands of images or more.
* `is_reg`
* Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization.

Expand Down
4 changes: 4 additions & 0 deletions docs/config_README-ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
| `image_dir` | `‘C:\hoge’` | - | - | o(必須) |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `“sks girl”` | - | - | o |
| `cache_info` | `false` | o | o | o |
| `is_reg` | `false` | - | - | o |

まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。
Expand All @@ -183,6 +184,9 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
* `class_tokens`
* クラストークンを設定します。
* 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。
* `cache_info`
* 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir` に `metadata_cache.json` というファイル名で保存されます。
* キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。
* `is_reg`
* サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。

Expand Down
4 changes: 4 additions & 0 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class DreamBoothSubsetParams(BaseSubsetParams):
is_reg: bool = False
class_tokens: Optional[str] = None
caption_extension: str = ".caption"
cache_info: bool = False


@dataclass
Expand All @@ -96,6 +97,7 @@ class FineTuningSubsetParams(BaseSubsetParams):
class ControlNetSubsetParams(BaseSubsetParams):
conditioning_data_dir: str = None
caption_extension: str = ".caption"
cache_info: bool = False


@dataclass
Expand Down Expand Up @@ -205,6 +207,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
DB_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
"class_tokens": str,
"cache_info": bool,
}
DB_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
Expand All @@ -217,6 +220,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
}
CN_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
"cache_info": bool,
}
CN_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
Expand Down
99 changes: 77 additions & 22 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import imagesize
import cv2
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
Expand Down Expand Up @@ -410,6 +411,7 @@ def __init__(
is_reg: bool,
class_tokens: Optional[str],
caption_extension: str,
cache_info: bool,
num_repeats,
shuffle_caption,
caption_separator: str,
Expand Down Expand Up @@ -458,6 +460,7 @@ def __init__(
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
self.cache_info = cache_info

def __eq__(self, other) -> bool:
if not isinstance(other, DreamBoothSubset):
Expand Down Expand Up @@ -527,6 +530,7 @@ def __init__(
image_dir: str,
conditioning_data_dir: str,
caption_extension: str,
cache_info: bool,
num_repeats,
shuffle_caption,
caption_separator,
Expand Down Expand Up @@ -574,6 +578,7 @@ def __init__(
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
self.cache_info = cache_info

def __eq__(self, other) -> bool:
if not isinstance(other, ControlNetSubset):
Expand Down Expand Up @@ -1081,8 +1086,7 @@ def cache_text_encoder_outputs(
)

def get_image_size(self, image_path):
image = Image.open(image_path)
return image.size
return imagesize.get(image_path)

def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
img = load_image(image_path)
Expand Down Expand Up @@ -1411,6 +1415,8 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index):


class DreamBoothDataset(BaseDataset):
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"

def __init__(
self,
subsets: Sequence[DreamBoothSubset],
Expand Down Expand Up @@ -1485,26 +1491,54 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
logger.warning(f"not directory: {subset.image_dir}")
return [], []

img_paths = glob_images(subset.image_dir, "*")
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
use_cached_info_for_subset = subset.cache_info
if use_cached_info_for_subset:
logger.info(
f"using cached image info for this subset / このサブセットで、キャッシュされた画像情報を使います: {info_cache_file}"
)
if not os.path.isfile(info_cache_file):
logger.warning(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
f"image info file not found. You can ignore this warning if this is the first time to use this subset"
+ " / キャッシュファイルが見つかりませんでした。初回実行時はこの警告を無視してください: {metadata_file}"
)
captions.append("")
missing_captions.append(img_path)
else:
if cap_for_img is None:
captions.append(subset.class_tokens)
use_cached_info_for_subset = False

if use_cached_info_for_subset:
# json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...}
with open(info_cache_file, "r", encoding="utf-8") as f:
metas = json.load(f)
img_paths = list(metas.keys())
sizes = [meta["resolution"] for meta in metas.values()]

# we may need to check image size and existence of image files, but it takes time, so user should check it before training
else:
img_paths = glob_images(subset.image_dir, "*")
sizes = [None] * len(img_paths)

logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

if use_cached_info_for_subset:
captions = [meta["caption"] for meta in metas.values()]
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
else:
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
logger.warning(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
)
captions.append("")
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
if cap_for_img is None:
captions.append(subset.class_tokens)
missing_captions.append(img_path)
else:
captions.append(cap_for_img)

self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録

Expand All @@ -1521,7 +1555,19 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
break
logger.warning(missing_caption)
return img_paths, captions

if not use_cached_info_for_subset and subset.cache_info:
logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}")
sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")]
matas = {}
for img_path, caption, size in zip(img_paths, captions, sizes):
matas[img_path] = {"caption": caption, "resolution": list(size)}
with open(info_cache_file, "w", encoding="utf-8") as f:
json.dump(matas, f, ensure_ascii=False, indent=2)
logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}")

# if sizes are not set, image size will be read in make_buckets
return img_paths, captions, sizes

logger.info("prepare images.")
num_train_images = 0
Expand All @@ -1540,7 +1586,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
)
continue

img_paths, captions = load_dreambooth_dir(subset)
img_paths, captions, sizes = load_dreambooth_dir(subset)
if len(img_paths) < 1:
logger.warning(
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
Expand All @@ -1552,8 +1598,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
else:
num_train_images += subset.num_repeats * len(img_paths)

for img_path, caption in zip(img_paths, captions):
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
if size is not None:
info.image_size = size
if subset.is_reg:
reg_infos.append((info, subset))
else:
Expand Down Expand Up @@ -1842,7 +1890,8 @@ def __init__(
subset.image_dir,
False,
None,
subset.caption_extension,
subset.caption_extension,
subset.cache_info,
subset.num_repeats,
subset.shuffle_caption,
subset.caption_separator,
Expand Down Expand Up @@ -3384,6 +3433,12 @@ def add_dataset_arguments(
parser.add_argument(
"--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--cache_info",
action="store_true",
help="cache meta information (caption and image size) for faster dataset loading. only available for DreamBooth"
+ " / メタ情報(キャプションとサイズ)をキャッシュしてデータセット読み込みを高速化する。DreamBooth方式のみ有効",
)
parser.add_argument(
"--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
)
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.20.1
# for Image utils
imagesize==1.4.1
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
Expand Down
7 changes: 1 addition & 6 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,14 @@
import torch
from library.device_utils import init_ipex, clean_memory_on_device


init_ipex()

from torch.nn.parallel import DistributedDataParallel as DDP

from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import deepspeed_utils, model_util

import library.train_util as train_util
from library.train_util import (
DreamBoothDataset,
)
from library.train_util import DreamBoothDataset
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
Expand Down