Skip to content
This repository has been archived by the owner on Oct 22, 2023. It is now read-only.

Add caption shuffling #119

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 41 additions & 15 deletions scripts/configuration_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ def create_default_variables(self):
self.save_sample_controlled_seed = []
self.delete_checkpoints_when_full_drive = True
self.use_image_names_as_captions = True
self.shuffle_captions = False
self.use_offset_noise = False
self.offset_noise_weight = 0.1
self.num_samples_to_generate = 1
Expand Down Expand Up @@ -1201,6 +1202,10 @@ def dreambooth_mode(self):
self.use_image_names_as_captions_checkbox.configure(state='disabled')
self.use_image_names_as_captions_var.set(0)
#self.use_image_names_as_captions_checkbox.set(0)
self.shuffle_captions_label.configure(state='disabled')
self.shuffle_captions_checkbox.configure(state='disabled')
self.shuffle_captions_var.set(0)
#self.shuffle_captions_checkbox.set(0)
self.add_class_images_to_dataset_checkbox.configure(state='disabled')
self.add_class_images_to_dataset_label.configure(state='disabled')
self.add_class_images_to_dataset_var.set(0)
Expand All @@ -1215,10 +1220,13 @@ def fine_tune_mode(self):
self.use_text_files_as_captions_label.configure(state='normal')
self.use_image_names_as_captions_label.configure(state='normal')
self.use_image_names_as_captions_checkbox.configure(state='normal')
self.shuffle_captions_label.configure(state='normal')
self.shuffle_captions_checkbox.configure(state='normal')
self.add_class_images_to_dataset_checkbox.configure(state='normal')
self.add_class_images_to_dataset_label.configure(state='normal')
self.use_text_files_as_captions_var.set(1)
self.use_image_names_as_captions_var.set(1)
self.shuffle_captions_var.set(0)
self.add_class_images_to_dataset_var.set(0)
except:
pass
Expand Down Expand Up @@ -1629,39 +1637,49 @@ def create_dataset_settings_widgets(self):
# create checkbox
self.use_image_names_as_captions_checkbox = ctk.CTkSwitch(self.dataset_frame_subframe, variable=self.use_image_names_as_captions_var)
self.use_image_names_as_captions_checkbox.grid(row=2, column=1, sticky="nsew")
# create shuffle captions checkbox
self.shuffle_captions_var = tk.IntVar()
self.shuffle_captions_var.set(self.shuffle_captions)
# create label
self.shuffle_captions_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Shuffle Captions")
shuffle_captions_label_ttp = CreateToolTip(self.shuffle_captions_label, "Randomize the order of tags in a caption. Tags are separated by ','. Used for training with booru-style captions.")
self.shuffle_captions_label.grid(row=3, column=0, sticky="nsew")
# create checkbox
self.shuffle_captions_checkbox = ctk.CTkSwitch(self.dataset_frame_subframe, variable=self.shuffle_captions_var)
self.shuffle_captions_checkbox.grid(row=3, column=1, sticky="nsew")
# create auto balance dataset checkbox
self.auto_balance_dataset_var = tk.IntVar()
self.auto_balance_dataset_var.set(self.auto_balance_concept_datasets)
# create label
self.auto_balance_dataset_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Auto Balance Dataset")
auto_balance_dataset_label_ttp = CreateToolTip(self.auto_balance_dataset_label, "Will use the concept with the least amount of images to balance the dataset by removing images from the other concepts.")
self.auto_balance_dataset_label.grid(row=3, column=0, sticky="nsew")
self.auto_balance_dataset_label.grid(row=4, column=0, sticky="nsew")
# create checkbox
self.auto_balance_dataset_checkbox = ctk.CTkSwitch(self.dataset_frame_subframe, variable=self.auto_balance_dataset_var)
self.auto_balance_dataset_checkbox.grid(row=3, column=1, sticky="nsew")
self.auto_balance_dataset_checkbox.grid(row=4, column=1, sticky="nsew")
#create add class images to dataset checkbox
self.add_class_images_to_dataset_var = tk.IntVar()
self.add_class_images_to_dataset_var.set(self.add_class_images_to_training)
#create label
self.add_class_images_to_dataset_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Add Class Images to Dataset")
add_class_images_to_dataset_label_ttp = CreateToolTip(self.add_class_images_to_dataset_label, "Will add class images without prior preservation to the dataset.")
self.add_class_images_to_dataset_label.grid(row=4, column=0, sticky="nsew")
self.add_class_images_to_dataset_label.grid(row=5, column=0, sticky="nsew")
#create checkbox
self.add_class_images_to_dataset_checkbox = ctk.CTkSwitch(self.dataset_frame_subframe, variable=self.add_class_images_to_dataset_var)
self.add_class_images_to_dataset_checkbox.grid(row=4, column=1, sticky="nsew")
self.add_class_images_to_dataset_checkbox.grid(row=5, column=1, sticky="nsew")
#create number of class images entry
self.number_of_class_images_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Number of Class Images")
number_of_class_images_label_ttp = CreateToolTip(self.number_of_class_images_label, "The number of class images to add to the dataset, if they don't exist in the class directory they will be generated.")
self.number_of_class_images_label.grid(row=5, column=0, sticky="nsew")
self.number_of_class_images_label.grid(row=6, column=0, sticky="nsew")
self.number_of_class_images_entry = ctk.CTkEntry(self.dataset_frame_subframe)
self.number_of_class_images_entry.grid(row=5, column=1, sticky="nsew")
self.number_of_class_images_entry.grid(row=6, column=1, sticky="nsew")
self.number_of_class_images_entry.insert(0, self.num_class_images)
#create dataset repeat entry
self.dataset_repeats_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Dataset Repeats")
dataset_repeat_label_ttp = CreateToolTip(self.dataset_repeats_label, "The number of times to repeat the dataset, this will increase the number of images in the dataset.")
self.dataset_repeats_label.grid(row=6, column=0, sticky="nsew")
self.dataset_repeats_label.grid(row=7, column=0, sticky="nsew")
self.dataset_repeats_entry = ctk.CTkEntry(self.dataset_frame_subframe)
self.dataset_repeats_entry.grid(row=6, column=1, sticky="nsew")
self.dataset_repeats_entry.grid(row=7, column=1, sticky="nsew")
self.dataset_repeats_entry.insert(0, self.dataset_repeats)

#add use_aspect_ratio_bucketing checkbox
Expand All @@ -1670,10 +1688,10 @@ def create_dataset_settings_widgets(self):
#create label
self.use_aspect_ratio_bucketing_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Use Aspect Ratio Bucketing")
use_aspect_ratio_bucketing_label_ttp = CreateToolTip(self.use_aspect_ratio_bucketing_label, "Will use aspect ratio bucketing, may improve aspect ratio generations.")
self.use_aspect_ratio_bucketing_label.grid(row=7, column=0, sticky="nsew")
self.use_aspect_ratio_bucketing_label.grid(row=8, column=0, sticky="nsew")
#create checkbox
self.use_aspect_ratio_bucketing_checkbox = ctk.CTkSwitch(self.dataset_frame_subframe, variable=self.use_aspect_ratio_bucketing_var)
self.use_aspect_ratio_bucketing_checkbox.grid(row=7, column=1, sticky="nsew")
self.use_aspect_ratio_bucketing_checkbox.grid(row=8, column=1, sticky="nsew")
#do something on checkbox click
self.use_aspect_ratio_bucketing_checkbox.bind("<Button-1>", self.aspect_ratio_mode_toggles)

Expand All @@ -1682,17 +1700,17 @@ def create_dataset_settings_widgets(self):
self.aspect_ratio_bucketing_mode_var.set(self.aspect_ratio_bucketing_mode)
self.aspect_ratio_bucketing_mode_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Aspect Ratio Bucketing Mode")
aspect_ratio_bucketing_mode_label_ttp = CreateToolTip(self.aspect_ratio_bucketing_mode_label, "Select what the Auto Bucketing will do in case the bucket doesn't match the batch size, dynamic will choose the least amount of adding/removing of images per bucket.")
self.aspect_ratio_bucketing_mode_label.grid(row=8, column=0, sticky="nsew")
self.aspect_ratio_bucketing_mode_label.grid(row=9, column=0, sticky="nsew")
self.aspect_ratio_bucketing_mode_option_menu = ctk.CTkOptionMenu(self.dataset_frame_subframe, variable=self.aspect_ratio_bucketing_mode_var, values=['Dynamic Fill', 'Drop Fill', 'Duplicate Fill'])
self.aspect_ratio_bucketing_mode_option_menu.grid(row=8, column=1, sticky="nsew")
self.aspect_ratio_bucketing_mode_option_menu.grid(row=9, column=1, sticky="nsew")
#option menu to select dynamic bucketing mode (if enabled)
self.dynamic_bucketing_mode_var = tk.StringVar()
self.dynamic_bucketing_mode_var.set(self.dynamic_bucketing_mode)
self.dynamic_bucketing_mode_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Dynamic Preference")
dynamic_bucketing_mode_label_ttp = CreateToolTip(self.dynamic_bucketing_mode_label, "If you're using dynamic mode, choose what you prefer in the case that dropping and duplicating are the same amount of images.")
self.dynamic_bucketing_mode_label.grid(row=9, column=0, sticky="nsew")
self.dynamic_bucketing_mode_label.grid(row=10, column=0, sticky="nsew")
self.dynamic_bucketing_mode_option_menu = ctk.CTkOptionMenu(self.dataset_frame_subframe, variable=self.dynamic_bucketing_mode_var, values=['Duplicate', 'Drop'])
self.dynamic_bucketing_mode_option_menu.grid(row=9, column=1, sticky="nsew")
self.dynamic_bucketing_mode_option_menu.grid(row=10, column=1, sticky="nsew")
#add shuffle dataset per epoch checkbox
self.shuffle_dataset_per_epoch_var = tk.IntVar()
self.shuffle_dataset_per_epoch_var.set(self.shuffle_dataset_per_epoch)
Expand Down Expand Up @@ -3067,6 +3085,7 @@ def save_config(self, config_file=None):
configure["with_prior_loss_preservation"] = self.with_prior_loss_preservation_var.get()
configure["prior_loss_preservation_weight"] = self.prior_loss_preservation_weight_entry.get()
configure["use_image_names_as_captions"] = self.use_image_names_as_captions_var.get()
configure["shuffle_captions"] = self.shuffle_captions_var.get()
configure["auto_balance_concept_datasets"] = self.auto_balance_dataset_var.get()
configure["add_class_images_to_dataset"] = self.add_class_images_to_dataset_var.get()
configure["number_of_class_images"] = self.number_of_class_images_entry.get()
Expand Down Expand Up @@ -3201,6 +3220,7 @@ def load_config(self,file_name=None):
self.prior_loss_preservation_weight_entry.delete(0, tk.END)
self.prior_loss_preservation_weight_entry.insert(0, configure["prior_loss_preservation_weight"])
self.use_image_names_as_captions_var.set(configure["use_image_names_as_captions"])
self.shuffle_captions_var.set(configure["shuffle_captions"])
self.auto_balance_dataset_var.set(configure["auto_balance_concept_datasets"])
self.add_class_images_to_dataset_var.set(configure["add_class_images_to_dataset"])
self.number_of_class_images_entry.delete(0, tk.END)
Expand Down Expand Up @@ -3296,6 +3316,7 @@ def process_inputs(self,export=None):
self.with_prior_loss_preservation = self.with_prior_loss_preservation_var.get()
self.prior_loss_preservation_weight = self.prior_loss_preservation_weight_entry.get()
self.use_image_names_as_captions = self.use_image_names_as_captions_var.get()
self.shuffle_captions = self.shuffle_captions_var.get()
self.auto_balance_concept_datasets = self.auto_balance_dataset_var.get()
self.add_class_images_to_dataset = self.add_class_images_to_dataset_var.get()
self.number_of_class_images = self.number_of_class_images_entry.get()
Expand Down Expand Up @@ -3376,7 +3397,7 @@ def process_inputs(self,export=None):
#check if resolution is the same
try:
#try because I keep adding stuff to the json file and it may error out for peeps
if self.last_run["resolution"] != self.resolution or self.use_text_files_as_captions != self.last_run['use_text_files_as_captions'] or self.last_run['dataset_repeats'] != self.dataset_repeats or self.last_run["batch_size"] != self.batch_size or self.last_run["train_text_encoder"] != self.train_text_encoder or self.last_run["use_image_names_as_captions"] != self.use_image_names_as_captions or self.last_run["auto_balance_concept_datasets"] != self.auto_balance_concept_datasets or self.last_run["add_class_images_to_dataset"] != self.add_class_images_to_dataset or self.last_run["number_of_class_images"] != self.number_of_class_images or self.last_run["aspect_ratio_bucketing"] != self.use_aspect_ratio_bucketing or self.last_run["masked_training"] != self.masked_training:
if self.last_run["resolution"] != self.resolution or self.use_text_files_as_captions != self.last_run['use_text_files_as_captions'] or self.last_run['dataset_repeats'] != self.dataset_repeats or self.last_run["batch_size"] != self.batch_size or self.last_run["train_text_encoder"] != self.train_text_encoder or self.last_run["use_image_names_as_captions"] != self.use_image_names_as_captions or self.last_run["shuffle_captions"] != self.shuffle_captions or self.last_run["auto_balance_concept_datasets"] != self.auto_balance_concept_datasets or self.last_run["add_class_images_to_dataset"] != self.add_class_images_to_dataset or self.last_run["number_of_class_images"] != self.number_of_class_images or self.last_run["aspect_ratio_bucketing"] != self.use_aspect_ratio_bucketing or self.last_run["masked_training"] != self.masked_training:
self.regenerate_latent_cache = True
#show message

Expand Down Expand Up @@ -3624,6 +3645,11 @@ def process_inputs(self,export=None):
batBase += ' --use_image_names_as_captions'
else:
batBase += f' "--use_image_names_as_captions" '
if self.shuffle_captions == True:
if export == 'Linux':
batBase += ' --shuffle_captions'
else:
batBase += f' "--shuffle_captions" '
if self.use_offset_noise == True:
if export == 'Linux':
batBase += f' --with_offset_noise'
Expand Down
18 changes: 16 additions & 2 deletions scripts/dataloaders_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def __init__(self,
resolution=512,
center_crop=False,
use_image_names_as_captions=True,
shuffle_captions=False,
add_class_images_to_dataset=None,
balance_datasets=False,
crop_jitter=20,
Expand All @@ -342,6 +343,7 @@ def __init__(self,
self.batch_size = batch_size
self.concepts_list = concepts_list
self.use_image_names_as_captions = use_image_names_as_captions
self.shuffle_captions = shuffle_captions
self.num_train_images = 0
self.num_reg_images = 0
self.image_train_items = []
Expand Down Expand Up @@ -447,16 +449,22 @@ def __get_image_for_trainer(self,image_train_item,debug_level=0,class_img=False)
image_train_tmp = image_train_item.hydrate(crop=False, save=0, crop_jitter=self.crop_jitter)
image_train_tmp_image = Image.fromarray(self.normalize8(image_train_tmp.image)).convert("RGB")

instance_prompt = image_train_tmp.caption
if self.shuffle_captions:
caption_parts = instance_prompt.split(",")
random.shuffle(caption_parts)
instance_prompt = ",".join(caption_parts)

example["instance_images"] = self.image_transforms(image_train_tmp_image)
if image_train_tmp.mask is not None:
image_train_tmp_mask = Image.fromarray(self.normalize8(image_train_tmp.mask)).convert("L")
example["mask"] = self.mask_transforms(image_train_tmp_mask)
if self.model_variant == 'depth2img':
image_train_tmp_depth = Image.fromarray(self.normalize8(image_train_tmp.extra)).convert("L")
example["instance_depth_images"] = self.depth_image_transforms(image_train_tmp_depth)
#print(image_train_tmp.caption)
#print(instance_prompt)
example["instance_prompt_ids"] = self.tokenizer(
image_train_tmp.caption,
instance_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
Expand Down Expand Up @@ -1051,6 +1059,7 @@ def __init__(
center_crop=False,
num_class_images=None,
use_image_names_as_captions=False,
shuffle_captions=False,
repeats=1,
use_text_files_as_captions=False,
seed=555,
Expand All @@ -1060,6 +1069,7 @@ def __init__(
load_mask=None,
):
self.use_image_names_as_captions = use_image_names_as_captions
self.shuffle_captions = shuffle_captions
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
Expand Down Expand Up @@ -1229,6 +1239,10 @@ def __getitem__(self, index):
instance_prompt = f.readline().rstrip()
f.close()

if self.shuffle_captions:
caption_parts = instance_prompt.split(",")
random.shuffle(caption_parts)
instance_prompt = ",".join(caption_parts)

#print('identifier: ' + instance_prompt)
instance_image = instance_image.convert("RGB")
Expand Down
3 changes: 3 additions & 0 deletions scripts/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def parse_args():
parser.add_argument('--append_sample_controlled_seed_action', action='append')
parser.add_argument('--add_sample_prompt', type=str, action='append')
parser.add_argument('--use_image_names_as_captions', default=False, action="store_true")
parser.add_argument('--shuffle_captions', default=False, action="store_true")
parser.add_argument("--masked_training", default=False, required=False, action='store_true', help="Whether to mask parts of the image during training")
parser.add_argument("--normalize_masked_area_loss", default=False, required=False, action='store_true', help="Normalize the loss, to make it independent of the size of the masked area")
parser.add_argument("--unmasked_probability", type=float, default=1, required=False, help="Probability of training a step without a mask")
Expand Down Expand Up @@ -612,6 +613,7 @@ def main():
train_dataset = AutoBucketing(
concepts_list=args.concepts_list,
use_image_names_as_captions=args.use_image_names_as_captions,
shuffle_captions=args.shuffle_captions,
batch_size=args.train_batch_size,
tokenizer=tokenizer,
add_class_images_to_dataset=args.add_class_images_to_dataset,
Expand All @@ -637,6 +639,7 @@ def main():
center_crop=args.center_crop,
num_class_images=args.num_class_images,
use_image_names_as_captions=args.use_image_names_as_captions,
shuffle_captions=args.shuffle_captions,
repeats=args.dataset_repeats,
use_text_files_as_captions=args.use_text_files_as_captions,
seed = args.seed,
Expand Down