Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

parquet: fix aspect bucketing | json: mild optimisation | llava: add 1.6 support | pillow: fix deprecations #350

Merged
merged 18 commits into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion helpers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,13 @@ def parse_args(input_args=None):
help="Whether or not to use stochastic bf16 in Adam. Currently the only supported optimizer.",
)
parser.add_argument(
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
"--max_grad_norm",
default=2.0,
type=float,
help=(
"Clipping the max gradient norm can help prevent exploding gradients, but"
" may also harm training by introducing artifacts or making it hard to train artifacts away."
),
)
parser.add_argument(
"--push_to_hub",
Expand Down
13 changes: 13 additions & 0 deletions helpers/caching/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,13 +622,20 @@ def _process_images_in_batch(
)
for data in initial_data
]
first_aspect_ratio = None
for future in futures:
try:
result = (
future.result()
) # Returns (image, crop_coordinates, new_aspect_ratio)
if result: # Ensure result is not None or invalid
processed_images.append(result)
if first_aspect_ratio is None:
first_aspect_ratio = result[2]
elif aspect_bucket != first_aspect_ratio:
raise ValueError(
f"Image {filepath} has a different aspect ratio ({aspect_bucket}) than the first image in the batch ({first_aspect_ratio})."
)
except Exception as e:
self.debug_log(
f"Error processing image in pool: {e}, traceback: {traceback.format_exc()}"
Expand All @@ -637,11 +644,17 @@ def _process_images_in_batch(
# Second Loop: Final Processing
is_final_sample = False
output_values = []
first_aspect_ratio = None
for idx, (image, crop_coordinates, new_aspect_ratio) in enumerate(
processed_images
):
if idx == len(processed_images) - 1:
is_final_sample = True
if first_aspect_ratio is None:
first_aspect_ratio = new_aspect_ratio
elif new_aspect_ratio != first_aspect_ratio:
is_final_sample = True
first_aspect_ratio = new_aspect_ratio
filepath, _, aspect_bucket = initial_data[idx]
filepaths.append(filepath)

Expand Down
2 changes: 2 additions & 0 deletions helpers/data_backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def print_bucket_info(metadata_backend):
# Print each bucket's information
for bucket in metadata_backend.aspect_ratio_bucket_indices:
image_count = len(metadata_backend.aspect_ratio_bucket_indices[bucket])
if image_count == 0:
continue
print(f"{rank_info()} | {bucket:<10} | {image_count:<12}")


Expand Down
38 changes: 25 additions & 13 deletions helpers/metadata/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def split_buckets_between_processes(self, gradient_accumulation_steps=1):
)
if total_images != post_total:
self.read_only = True

logger.debug(f"Count of items after split: {post_total}")

def mark_as_seen(self, image_path):
Expand Down Expand Up @@ -386,14 +387,20 @@ def _enforce_min_bucket_size(self):
"""
Remove buckets that have fewer samples than batch_size and enforce minimum image size constraints.
"""
for bucket in list(
self.aspect_ratio_bucket_indices.keys()
logger.info(
f"Enforcing minimum image size of {self.minimum_image_size}."
" This could take a while for very-large datasets."
)
for bucket in tqdm(
list(self.aspect_ratio_bucket_indices.keys()),
leave=False,
desc="Enforcing minimum bucket size",
): # Safe iteration over keys
# # Prune the smaller buckets so that we don't enforce resolution constraints on them unnecessarily.
# self._prune_small_buckets(bucket)
# Prune the smaller buckets so that we don't enforce resolution constraints on them unnecessarily.
self._prune_small_buckets(bucket)
self._enforce_resolution_constraints(bucket)
# # We do this twice in case there were any new contenders for being too small.
# self._prune_small_buckets(bucket)
# We do this twice in case there were any new contenders for being too small.
self._prune_small_buckets(bucket)

def _prune_small_buckets(self, bucket):
"""
Expand All @@ -411,10 +418,6 @@ def _enforce_resolution_constraints(self, bucket):
Enforce resolution constraints on images in a bucket.
"""
if self.minimum_image_size is not None:
logger.info(
f"Enforcing minimum image size of {self.minimum_image_size}."
" This could take a while for very-large datasets."
)
if bucket not in self.aspect_ratio_bucket_indices:
logger.debug(
f"Bucket {bucket} was already removed due to insufficient samples."
Expand Down Expand Up @@ -736,19 +739,25 @@ def handle_vae_cache_inconsistencies(self, vae_cache, vae_cache_behavior: str):
# Update any state or metadata post-processing
self.save_cache()

def _recalculate_target_resolution(self, original_resolution: tuple) -> tuple:
def _recalculate_target_resolution(
self, original_resolution: tuple, original_aspect_ratio: float = None
) -> tuple:
"""Given the original resolution, use our backend config to properly recalculate the size."""
resolution_type = StateTracker.get_data_backend_config(self.id)[
"resolution_type"
]
resolution = StateTracker.get_data_backend_config(self.id)["resolution"]
if resolution_type == "pixel":
return MultiaspectImage.calculate_new_size_by_pixel_area(
return MultiaspectImage.calculate_new_size_by_pixel_edge(
original_resolution[0], original_resolution[1], resolution
)
elif resolution_type == "area":
if original_aspect_ratio is None:
raise ValueError(
"Original aspect ratio must be provided for area-based resolution."
)
return MultiaspectImage.calculate_new_size_by_pixel_area(
original_resolution[0], original_resolution[1], resolution
original_aspect_ratio, resolution
)

def is_cache_inconsistent(self, vae_cache, cache_file, cache_content):
Expand Down Expand Up @@ -786,6 +795,9 @@ def is_cache_inconsistent(self, vae_cache, cache_file, cache_content):
)
return True
target_resolution = tuple(metadata_target_size)
original_aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio(
original_resolution
)
recalculated_width, recalculated_height, recalculated_aspect_ratio = (
self._recalculate_target_resolution(original_resolution)
)
Expand Down
13 changes: 5 additions & 8 deletions helpers/metadata/backends/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,19 +229,16 @@ def _process_for_bucket(
image_metadata["crop_coordinates"] = crop_coordinates
image_metadata["target_size"] = image.size
# Round to avoid excessive unique buckets
aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio(
image, aspect_ratio_rounding
)
image_metadata["aspect_ratio"] = aspect_ratio
image_metadata["aspect_ratio"] = new_aspect_ratio
image_metadata["luminance"] = calculate_luminance(image)
logger.debug(
f"Image {image_path_str} has aspect ratio {aspect_ratio} and size {image.size}."
f"Image {image_path_str} has aspect ratio {new_aspect_ratio} and size {image.size}."
)

# Create a new bucket if it doesn't exist
if str(aspect_ratio) not in aspect_ratio_bucket_indices:
aspect_ratio_bucket_indices[str(aspect_ratio)] = []
aspect_ratio_bucket_indices[str(aspect_ratio)].append(image_path_str)
if str(new_aspect_ratio) not in aspect_ratio_bucket_indices:
aspect_ratio_bucket_indices[str(new_aspect_ratio)] = []
aspect_ratio_bucket_indices[str(new_aspect_ratio)].append(image_path_str)
# Instead of directly updating, just fill the provided dictionary
if metadata_updates is not None:
metadata_updates[image_path_str] = image_metadata
Expand Down
8 changes: 4 additions & 4 deletions helpers/metadata/backends/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def _process_for_bucket(
)
image_metadata["crop_coordinates"] = crop_coordinates
image_metadata["target_size"] = target_size
image_metadata["aspect_ratio"] = aspect_ratio
image_metadata["aspect_ratio"] = new_aspect_ratio
luminance_column = self.parquet_config.get("luminance_column", None)
if luminance_column:
image_metadata["luminance"] = database_image_metadata[
Expand All @@ -377,9 +377,9 @@ def _process_for_bucket(
)

# Create a new bucket if it doesn't exist
if str(aspect_ratio) not in aspect_ratio_bucket_indices:
aspect_ratio_bucket_indices[str(aspect_ratio)] = []
aspect_ratio_bucket_indices[str(aspect_ratio)].append(image_path_str)
if str(new_aspect_ratio) not in aspect_ratio_bucket_indices:
aspect_ratio_bucket_indices[str(new_aspect_ratio)] = []
aspect_ratio_bucket_indices[str(new_aspect_ratio)].append(image_path_str)
# Instead of directly updating, just fill the provided dictionary
if metadata_updates is not None:
metadata_updates[image_path_str] = image_metadata
Expand Down
7 changes: 7 additions & 0 deletions helpers/multiaspect/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,15 @@ def __len__(self):

def __getitem__(self, image_tuple):
output_data = []
first_aspect_ratio = None
for sample in image_tuple:
image_metadata = sample
if first_aspect_ratio is None:
first_aspect_ratio = image_metadata["aspect_ratio"]
elif first_aspect_ratio != image_metadata["aspect_ratio"]:
raise ValueError(
f"Aspect ratios must be the same for all images in a batch. Expected: {first_aspect_ratio}, got: {image_metadata['aspect_ratio']}"
)
if (
image_metadata["original_size"] is None
or image_metadata["target_size"] is None
Expand Down
22 changes: 13 additions & 9 deletions helpers/multiaspect/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,12 @@ def prepare_image(
)
)
elif resolution_type == "area":
original_aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio(
(original_width, original_height)
)
(target_width, target_height, new_aspect_ratio) = (
MultiaspectImage.calculate_new_size_by_pixel_area(
original_width, original_height, resolution
original_aspect_ratio, resolution
)
)
# Convert 'resolution' from eg. "1 megapixel" to "1024 pixels"
Expand Down Expand Up @@ -129,8 +132,7 @@ def prepare_image(
original_megapixel_resolution = original_resolution / 1e3
(target_width, target_height, new_aspect_ratio) = (
MultiaspectImage.calculate_new_size_by_pixel_area(
original_width,
original_height,
original_aspect_ratio,
original_megapixel_resolution,
)
)
Expand Down Expand Up @@ -345,7 +347,9 @@ def _resize_image(

new_image_size = (intermediate_width, intermediate_height)
if input_image:
input_image = input_image.resize(new_image_size, resample=Image.LANCZOS)
input_image = input_image.resize(
new_image_size, resample=Image.Resampling.LANCZOS
)
current_width, current_height = new_image_size
logger.debug(
f"Resized image to intermediate size: {current_width}x{current_height}."
Expand All @@ -354,7 +358,7 @@ def _resize_image(
# Final resize to target dimensions
if input_image:
input_image = input_image.resize(
(target_width, target_height), resample=Image.LANCZOS
(target_width, target_height), resample=Image.Resampling.LANCZOS
)
logger.debug(f"Final image size: {input_image.size}.")
return input_image
Expand Down Expand Up @@ -389,7 +393,7 @@ def is_image_too_large(image_size: tuple, resolution: float, resolution_type: st
raise ValueError(f"Unknown resolution type: {resolution_type}")

@staticmethod
def calculate_new_size_by_pixel_edge(W: int, H: int, resolution: float):
def calculate_new_size_by_pixel_edge(W: int, H: int, resolution: int):
aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio((W, H))
if W < H:
W = resolution
Expand All @@ -408,9 +412,9 @@ def calculate_new_size_by_pixel_edge(W: int, H: int, resolution: float):
return int(W), int(H), new_aspect_ratio

@staticmethod
def calculate_new_size_by_pixel_area(W: int, H: int, megapixels: float):
# Calculate initial dimensions based on aspect ratio and target megapixels
aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio((W, H))
def calculate_new_size_by_pixel_area(aspect_ratio: float, megapixels: float):
if type(aspect_ratio) != float:
raise ValueError(f"Aspect ratio must be a float, not {type(aspect_ratio)}")
total_pixels = max(megapixels * 1e6, 1e6)
W_initial = int(round((total_pixels * aspect_ratio) ** 0.5))
H_initial = int(round((total_pixels / aspect_ratio) ** 0.5))
Expand Down
6 changes: 6 additions & 0 deletions helpers/multiaspect/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,11 +341,17 @@ def __iter__(self):
to_yield = self._validate_and_yield_images_from_samples(
samples, self.buckets[self.current_bucket]
)
self.debug_log(
f"Building batch with {len(self.batch_accumulator)} samples."
)
if len(self.batch_accumulator) < self.batch_size:
remaining_entries_needed = self.batch_size - len(
self.batch_accumulator
)
# Now we'll add only remaining_entries_needed amount to the accumulator:
self.debug_log(
f"Current bucket: {self.current_bucket}. Adding samples with aspect ratios: {[i['aspect_ratio'] for i in to_yield[:remaining_entries_needed]]}"
)
self.batch_accumulator.extend(to_yield[:remaining_entries_needed])
# If the batch is full, yield it
if len(self.batch_accumulator) >= self.batch_size:
Expand Down
55 changes: 29 additions & 26 deletions helpers/training/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,27 +101,7 @@ def compute_latents(filepaths, data_backend_id: str):
logger.error(f"(id={data_backend_id}) Error while computing latents: {e}")
raise

# Validate shapes
test_shape = latents[0].shape
for idx, latent in enumerate(latents):
# Are there any inf or nan positions?
if torch.isnan(latent).any() or torch.isinf(latent).any():
# get the data_backend
data_backend = StateTracker.get_data_backend(data_backend_id)
# remove the object
data_backend["vaecache"].data_backend.delete(filepaths[idx])
raise ValueError(
f"(id={data_backend_id}) Deleted cache file {filepaths[idx]}: contains NaN or Inf values: {latent}"
)
if latent.shape != test_shape:
raise ValueError(
f"(id={data_backend_id}) File {filepaths[idx]} latent shape mismatch: {latent.shape} != {test_shape}"
)

debug_log(f" -> stacking {len(latents)} latents")
return torch.stack(
[latent.to(StateTracker.get_accelerator().device) for latent in latents]
)
return latents


def compute_single_embedding(caption, text_embed_cache, is_sdxl):
Expand Down Expand Up @@ -212,11 +192,34 @@ def gather_conditional_size_features(examples, latents, weight_dtype):
return torch.stack(batch_time_ids_list, dim=0)


def check_latent_shapes(latents, filepaths):
reference_shape = latents[0].shape
def check_latent_shapes(latents, filepaths, data_backend_id, batch):
# Validate shapes
test_shape = latents[0].shape
# Check all "aspect_ratio" values and raise error if any differ, with the two differing values:
for example in batch[0]:
if example["aspect_ratio"] != batch[0][0]["aspect_ratio"]:
raise ValueError(
f"Aspect ratio mismatch: {example['aspect_ratio']} != {batch[0][0]['aspect_ratio']}"
)
for idx, latent in enumerate(latents):
if latent.shape != reference_shape:
print(f"Latent shape mismatch for file: {filepaths[idx]}")
# Are there any inf or nan positions?
if torch.isnan(latent).any() or torch.isinf(latent).any():
# get the data_backend
data_backend = StateTracker.get_data_backend(data_backend_id)
# remove the object
data_backend["vaecache"].data_backend.delete(filepaths[idx])
raise ValueError(
f"(id={data_backend_id}) Deleted cache file {filepaths[idx]}: contains NaN or Inf values: {latent}"
)
if latent.shape != test_shape:
raise ValueError(
f"(id={data_backend_id}) File {filepaths[idx]} latent shape mismatch: {latent.shape} != {test_shape}"
)

debug_log(f" -> stacking {len(latents)} latents")
return torch.stack(
[latent.to(StateTracker.get_accelerator().device) for latent in latents]
)


def collate_fn(batch):
Expand Down Expand Up @@ -252,7 +255,7 @@ def collate_fn(batch):
debug_log("Compute latents")
latent_batch = compute_latents(filepaths, data_backend_id)
debug_log("Check latents")
check_latent_shapes(latent_batch, filepaths)
latent_batch = check_latent_shapes(latent_batch, filepaths, data_backend_id, batch)

# Compute embeddings and handle dropped conditionings
debug_log("Extract captions")
Expand Down
2 changes: 1 addition & 1 deletion kohya_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@
"scale_weight_norms": {
"parameter": "max_grad_norm",
"warn_if_value": 0,
"warning": "In SimpleTuner, max_grad_norm is set to 1 by default. Please change this if you see issues.",
"warning": "In SimpleTuner, max_grad_norm is set to 2 by default. Please change this if you see issues.",
},
"sdxl": {"script_name": "train_sdxl.py"},
"sdxl_cache_text_encoder_outputs": {
Expand Down
Loading
Loading