
sdxl refiner: ability to validate on images using 20% denoise strength | deepfloyd: stage II eval fixes | factory should sleep when waiting for text embed write #402

Merged: 3 commits, May 14, 2024
10 changes: 10 additions & 0 deletions helpers/arguments.py
@@ -857,6 +857,15 @@ def parse_args(input_args=None):
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--validation_using_datasets",
action="store_true",
default=None,
help=(
"When set, validation will use images sampled randomly from each dataset for validation."
" Be mindful of privacy issues when publishing training data to the internet."
),
)
parser.add_argument(
"--webhook_config",
type=str,
@@ -1418,6 +1427,7 @@ def parse_args(input_args=None):
"DeepFloyd Stage II requires a resolution of at least 256. Setting to 256."
)
args.resolution = 256
args.aspect_bucket_alignment = 64
args.resolution_type = "pixel"

if (
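Note on the new flag: it is declared as a store_true argument with default=None rather than False, so downstream code (see the train_sdxl.py hunk below) can tell "not specified" apart from "explicitly set" and auto-enable dataset-based validation when training the SDXL refiner. A minimal standalone sketch of that behaviour (not the project's actual parser):

import argparse

# Minimal sketch: default=None keeps "flag not passed" distinguishable from
# "flag passed", which a plain store_true default of False would not.
parser = argparse.ArgumentParser()
parser.add_argument("--validation_using_datasets", action="store_true", default=None)

unset = parser.parse_args([])
print(unset.validation_using_datasets)    # None -> caller may auto-enable it (e.g. for the refiner)

enabled = parser.parse_args(["--validation_using_datasets"])
print(enabled.validation_using_datasets)  # True -> user explicitly opted in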
3 changes: 0 additions & 3 deletions helpers/caching/sdxl_embeds.py
@@ -591,9 +591,6 @@ def compute_embeddings_for_legacy_prompts(
self.process_write_batches = False

if not return_concat:
logger.info(
"Not returning embeds, since we just concatenated a whackload of them."
)
del prompt_embeds_all
gc.collect()
return
4 changes: 3 additions & 1 deletion helpers/data_backend/factory.py
@@ -12,7 +12,7 @@
from helpers.training.collate import collate_fn
from helpers.training.state_tracker import StateTracker

import json, os, torch, logging, io
import json, os, torch, logging, io, time

logger = logging.getLogger("DataBackendFactory")

@@ -233,6 +233,7 @@ def configure_multi_databackend(
raise FileNotFoundError(
f"Data backend config file {args.data_backend_config} not found."
)
logger.info(f"Loading data backend config from {args.data_backend_config}")
with open(args.data_backend_config, "r") as f:
data_backend_config = json.load(f)
if len(data_backend_config) == 0:
@@ -317,6 +318,7 @@ def configure_multi_databackend(
init_backend["text_embed_cache"].compute_embeddings_for_prompts(
[""], return_concat=False, load_from_cache=False
)
time.sleep(5)
accelerator.wait_for_everyone()
if args.caption_dropout_probability == 0.0:
logger.warning(
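The five-second pause before accelerator.wait_for_everyone() exists because the text embed cache hands its disk writes to a background worker; without the pause, other ranks can clear the barrier and try to read the null-prompt embed before it has actually been written. A self-contained sketch of that race, with a hypothetical AsyncEmbedCache standing in for the project's cache:

import threading, time

class AsyncEmbedCache:
    """Hypothetical stand-in for the text embed cache used above."""

    def compute_embeddings_for_prompts(self, prompts, return_concat=True, load_from_cache=True):
        # The real cache also queues its writes on a background worker.
        threading.Thread(target=self._write_to_disk, args=(prompts,)).start()

    def _write_to_disk(self, prompts):
        time.sleep(2)  # simulated slow disk / S3 write
        print(f"embeds written for {prompts!r}")

cache = AsyncEmbedCache()
cache.compute_embeddings_for_prompts([""], return_concat=False, load_from_cache=False)
time.sleep(5)  # give the background write time to land before the cross-rank barrier
# accelerator.wait_for_everyone() would follow here in the real factory code
print("safe for all ranks to continue")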
11 changes: 7 additions & 4 deletions helpers/legacy/validation.py
@@ -1,4 +1,4 @@
import logging, os, torch, numpy as np
import logging, os, time, torch, numpy as np
from tqdm import tqdm
from diffusers.utils import is_wandb_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -24,7 +24,7 @@
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL") or "INFO")


def deepfloyd_validation_images():
def retrieve_validation_images():
"""
From each data backend, collect the top 5 images for validation, such that
we select the same images on each startup, unless the dataset changes.
@@ -74,7 +74,7 @@ def prepare_validation_prompt_list(args, embed_cache):
if "deepfloyd-stage2" in StateTracker.get_args().model_type:
# Now, we prepare the DeepFloyd upscaler image inputs so that we can calculate their prompts.
# If we don't do it here, they won't be available at inference time.
validation_sample_images = deepfloyd_validation_images()
validation_sample_images = retrieve_validation_images()
if len(validation_sample_images) > 0:
StateTracker.set_validation_sample_images(validation_sample_images)
# Collect the prompts for the validation images.
@@ -84,7 +84,10 @@ def prepare_validation_prompt_list(args, embed_cache):
desc="Precomputing DeepFloyd stage 2 eval prompt embeds",
):
_, validation_prompt, _ = _validation_sample
embed_cache.compute_embeddings_for_prompts([validation_prompt])
embed_cache.compute_embeddings_for_prompts(
[validation_prompt], load_from_cache=False
)
time.sleep(5)

if args.validation_prompt_library:
# Use the SimpleTuner prompts library for validation prompts.
9 changes: 8 additions & 1 deletion helpers/multiaspect/sampler.py
@@ -4,6 +4,7 @@
from PIL.ImageOps import exif_transpose
from helpers.training.multi_process import rank_info
from helpers.metadata.backends.base import MetadataBackend
from helpers.image_manipulation.training_sample import TrainingSample
from helpers.multiaspect.state import BucketStateManager
from helpers.data_backend.base import BaseDataBackend
from helpers.training.state_tracker import StateTracker
@@ -141,6 +142,10 @@ def retrieve_validation_set(self, batch_size: int):
image_path = self._yield_random_image()
image_data = self.data_backend.read_image(image_path)
image_metadata = self.metadata_backend.get_metadata_by_filepath(image_path)
training_sample = TrainingSample(
image=image_data, data_backend_id=self.id, image_metadata=image_metadata
)
training_sample.prepare()
validation_shortname = f"{self.id}_{img_idx}"
validation_prompt = PromptHandler.magic_prompt(
sampler_backend_id=self.id,
@@ -151,7 +156,9 @@ def retrieve_validation_set(self, batch_size: int):
prepend_instance_prompt=self.prepend_instance_prompt,
instance_prompt=self.instance_prompt,
)
results.append((validation_shortname, validation_prompt, image_data))
results.append(
(validation_shortname, validation_prompt, training_sample.image)
)

return results

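With this change, each randomly sampled validation image is run through the same TrainingSample preparation (EXIF transpose, crop/resize to its aspect bucket) as real training data, so the img2img validation input matches the geometry the model trains on. A rough PIL-only stand-in for what prepare() accomplishes (the real class also consults the stored image metadata):

from PIL import Image
from PIL.ImageOps import exif_transpose, fit

def prepare_validation_image(path: str, target_size=(1024, 1024)) -> Image.Image:
    """Hypothetical stand-in: scale and centre-crop an image to a bucket resolution."""
    image = exif_transpose(Image.open(path)).convert("RGB")
    # ImageOps.fit scales then centre-crops so the result is exactly target_size.
    return fit(image, target_size, method=Image.Resampling.LANCZOS)

sample = prepare_validation_image("dataset/example.jpg")  # placeholder path
sample.save("prepared_validation_input.png")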
64 changes: 56 additions & 8 deletions helpers/training/validation.py
@@ -4,6 +4,8 @@
from helpers.training.state_tracker import StateTracker
from helpers.sdxl.pipeline import StableDiffusionXLPipeline
from helpers.legacy.pipeline import DiffusionPipeline
from helpers.legacy.validation import retrieve_validation_images
from diffusers.pipelines import StableDiffusionXLImg2ImgPipeline
from diffusers.training_utils import EMAModel
from diffusers.schedulers import (
EulerDiscreteScheduler,
@@ -123,6 +125,7 @@ def __init__(
self.ema_unet = ema_unet
self.vae = vae
self.pipeline = None
self._discover_validation_input_samples()
self.validation_resolutions = (
get_validation_resolutions()
if "deepfloyd-stage2" not in args.model_type
@@ -131,9 +134,30 @@

self._update_state()

def _discover_validation_input_samples(self):
"""
If we have some workflow that requires image inputs for validation, we'll bind those now.

Returns:
Validation object (self)
"""
self.validation_image_inputs = None
if (
"deepfloyd-stage2" in StateTracker.get_args().model_type
or StateTracker.get_args().validation_using_datasets
):
self.validation_image_inputs = retrieve_validation_images()
# Validation inputs are in the format of a list of tuples:
# [(shortname, prompt, image), ...]
logger.info(
f"Image inputs discovered for validation: {self.validation_image_inputs}"
)

def _pipeline_cls(self):
model_type = StateTracker.get_model_type()
if model_type == "sdxl":
if StateTracker.get_args().validation_using_datasets:
return StableDiffusionXLImg2ImgPipeline
return StableDiffusionXLPipeline
elif model_type == "legacy":
if "deepfloyd-stage2" in self.args.model_type:
@@ -368,18 +392,33 @@ def setup_pipeline(self, validation_type):
def process_prompts(self):
"""Processes each validation prompt and logs the result."""
validation_images = {}
for shortname, prompt in zip(
self.validation_shortnames, self.validation_prompts
):
_content = zip(self.validation_shortnames, self.validation_prompts)
if self.validation_image_inputs:
# Override the pipeline inputs to be entirely based upon the validation image inputs.
_content = self.validation_image_inputs
for content in _content:
validation_input_image = None
if len(content) == 3:
shortname, prompt, validation_input_image = content
elif len(content) == 2:
shortname, prompt = content
else:
raise ValueError(
f"Validation content is not in the correct format: {content}"
)
logger.debug(f"Processing validation for prompt: {prompt}")
validation_images.update(self.validate_prompt(prompt, shortname))
validation_images.update(
self.validate_prompt(prompt, shortname, validation_input_image)
)
self._save_images(validation_images, shortname, prompt)
self._log_validations_to_webhook(validation_images, shortname, prompt)
logger.debug(f"Completed generating image: {prompt}")
self.validation_images = validation_images
self._log_validations_to_trackers(validation_images)

def validate_prompt(self, prompt, validation_shortname):
def validate_prompt(
self, prompt, validation_shortname, validation_input_image=None
):
"""Generate validation images for a single prompt."""
# Placeholder for actual image generation and logging
logger.info(f"Validating prompt: {prompt}")
@@ -393,16 +432,25 @@ def validate_prompt(self, prompt, validation_shortname):
logger.info(
f"Using a generator? {extra_validation_kwargs['generator']}"
)
if "deepfloyd-stage2" not in self.args.model_type:
validation_resolution_width, validation_resolution_height = resolution
else:
if validation_input_image is not None:
extra_validation_kwargs["image"] = validation_input_image
validation_resolution_width, validation_resolution_height = (
val * 4 for val in extra_validation_kwargs["image"].size
)
else:
validation_resolution_width, validation_resolution_height = resolution

if "deepfloyd" not in self.args.model_type:
extra_validation_kwargs["guidance_rescale"] = (
self.args.validation_guidance_rescale
)
if StateTracker.get_args().validation_using_datasets:
extra_validation_kwargs["strength"] = getattr(
self.args, "validation_strength", 0.2
)
logger.debug(
f"Set validation image denoise strength to {extra_validation_kwargs['strength']}"
)

logger.debug(
f"Processing width/height: {validation_resolution_width}x{validation_resolution_height}"
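Put together, these changes route dataset-based (and SDXL refiner) validation through diffusers' StableDiffusionXLImg2ImgPipeline with a low denoise strength, so each sampled image is only lightly re-noised rather than generated from scratch. A rough standalone sketch of the equivalent call, with placeholder paths and the stock refiner checkpoint assumed:

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")

init_image = load_image("validation_sample.png")  # placeholder path
result = pipe(
    prompt="a validation prompt drawn from the dataset",
    image=init_image,
    strength=0.2,  # only the final 20% of the noise schedule is denoised
).images[0]
result.save("validation_output.png")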
5 changes: 5 additions & 0 deletions train_sdxl.py
@@ -292,6 +292,11 @@ def main():
if tokenizer_1 is None:
logger.info("Seems that we are training an SDXL refiner model.")
StateTracker.is_sdxl_refiner(True)
if args.validation_using_datasets is None:
logger.warning(
f"Since we are training the SDXL refiner and --validation_using_datasets was not specified, it is now being enabled."
)
args.validation_using_datasets = True
except:
logger.warning(
"Could not load secondary tokenizer (OpenCLIP-G/14). Cannot continue."