diff --git a/configure.py b/configure.py index da3c3d94..7adf84da 100644 --- a/configure.py +++ b/configure.py @@ -789,6 +789,26 @@ def configure_env(): "Enter the path to your dataset. This should be a directory containing images and text files for their caption. For reliability, use an absolute (full) path, beginning with a '/'", "/datasets/my-dataset", ) + dataset_caption_strategy = prompt_user( + ( + "How should the dataloader handle captions?" + "\n-> 'filename' will use the names of your image files as the caption" + "\n-> 'textfile' requires a image.txt file to go next to your image.png file" + "\n-> 'instanceprompt' will just use one trigger phrase for all images" + "\n" + "\n(Options: filename, textfile, instanceprompt)" + ), + "textfile", + ) + if dataset_caption_strategy not in ["filename", "textfile", "instanceprompt"]: + print(f"Invalid caption strategy: {dataset_caption_strategy}") + dataset_caption_strategy = "textfile" + dataset_instance_prompt = None + if "instanceprompt" in dataset_caption_strategy: + dataset_instance_prompt = prompt_user( + "Enter the instance_prompt you want to use for all images in this dataset", + "CatchPhrase", + ) dataset_repeats = int( prompt_user( "How many times do you want to repeat each image in the dataset?", 10 @@ -818,6 +838,9 @@ def configure_env(): dataset["maximum_image_size"] = dataset["resolution"] dataset["target_downsample_size"] = dataset["resolution"] dataset["id"] = dataset["id"].replace("PLACEHOLDER", dataset_id) + if dataset_instance_prompt: + dataset["instance_prompt"] = dataset_instance_prompt + dataset["caption_strategy"] = dataset_caption_strategy print("Dataloader configuration:") print(default_local_configuration)