From ddc93f812e2ffafbc5fb626851649834d7caca76 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 11 Apr 2024 07:49:26 -0600 Subject: [PATCH 01/18] fix deprecation warning in Pillow --- helpers/multiaspect/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index 0be77d39..92bdcade 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -354,7 +354,7 @@ def _resize_image( # Final resize to target dimensions if input_image: input_image = input_image.resize( - (target_width, target_height), resample=Image.LANCZOS + (target_width, target_height), resample=Image.Resampling.LANCZOS ) logger.debug(f"Final image size: {input_image.size}.") return input_image From ad718370be47f539ac184e572b09eaeadcd56b17 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 11 Apr 2024 09:33:52 -0600 Subject: [PATCH 02/18] fix deprecation warning in Pillow (missed one) --- helpers/multiaspect/image.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index 92bdcade..511bd666 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -345,7 +345,9 @@ def _resize_image( new_image_size = (intermediate_width, intermediate_height) if input_image: - input_image = input_image.resize(new_image_size, resample=Image.LANCZOS) + input_image = input_image.resize( + new_image_size, resample=Image.Resampling.LANCZOS + ) current_width, current_height = new_image_size logger.debug( f"Resized image to intermediate size: {current_width}x{current_height}." From 7d7de073894e2444c30ae94d1ab4e7c20036021f Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 12 Apr 2024 08:30:16 -0600 Subject: [PATCH 03/18] add regression test for mixed latent issue --- helpers/metadata/backends/base.py | 16 ++++- helpers/metadata/backends/parquet.py | 4 ++ helpers/multiaspect/image.py | 16 +++-- tests/test_image.py | 98 +++++++++++++++++++++++++++- 4 files changed, 121 insertions(+), 13 deletions(-) diff --git a/helpers/metadata/backends/base.py b/helpers/metadata/backends/base.py index a4873e6a..f53fd339 100644 --- a/helpers/metadata/backends/base.py +++ b/helpers/metadata/backends/base.py @@ -302,6 +302,7 @@ def split_buckets_between_processes(self, gradient_accumulation_steps=1): ) if total_images != post_total: self.read_only = True + logger.debug(f"Count of items after split: {post_total}") def mark_as_seen(self, image_path): @@ -736,19 +737,25 @@ def handle_vae_cache_inconsistencies(self, vae_cache, vae_cache_behavior: str): # Update any state or metadata post-processing self.save_cache() - def _recalculate_target_resolution(self, original_resolution: tuple) -> tuple: + def _recalculate_target_resolution( + self, original_resolution: tuple, original_aspect_ratio: float = None + ) -> tuple: """Given the original resolution, use our backend config to properly recalculate the size.""" resolution_type = StateTracker.get_data_backend_config(self.id)[ "resolution_type" ] resolution = StateTracker.get_data_backend_config(self.id)["resolution"] if resolution_type == "pixel": - return MultiaspectImage.calculate_new_size_by_pixel_area( + return MultiaspectImage.calculate_new_size_by_pixel_edge( original_resolution[0], original_resolution[1], resolution ) elif resolution_type == "area": + if original_aspect_ratio is None: + raise ValueError( + "Original aspect ratio must be provided for area-based resolution." + ) return MultiaspectImage.calculate_new_size_by_pixel_area( - original_resolution[0], original_resolution[1], resolution + original_aspect_ratio, resolution ) def is_cache_inconsistent(self, vae_cache, cache_file, cache_content): @@ -786,6 +793,9 @@ def is_cache_inconsistent(self, vae_cache, cache_file, cache_content): ) return True target_resolution = tuple(metadata_target_size) + original_aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( + original_resolution + ) recalculated_width, recalculated_height, recalculated_aspect_ratio = ( self._recalculate_target_resolution(original_resolution) ) diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index 4ab85fc1..5d9f6670 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -362,6 +362,10 @@ def _process_for_bucket( id=self.data_backend.id, ) ) + if aspect_ratio != new_aspect_ratio: + raise ValueError( + f"Aspect ratio mismatch: {aspect_ratio} vs {new_aspect_ratio} for sample: {image_path_str}." + ) image_metadata["crop_coordinates"] = crop_coordinates image_metadata["target_size"] = target_size image_metadata["aspect_ratio"] = aspect_ratio diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index 511bd666..751982db 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -89,9 +89,12 @@ def prepare_image( ) ) elif resolution_type == "area": + original_aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( + (original_width, original_height) + ) (target_width, target_height, new_aspect_ratio) = ( MultiaspectImage.calculate_new_size_by_pixel_area( - original_width, original_height, resolution + original_aspect_ratio, resolution ) ) # Convert 'resolution' from eg. "1 megapixel" to "1024 pixels" @@ -129,8 +132,7 @@ def prepare_image( original_megapixel_resolution = original_resolution / 1e3 (target_width, target_height, new_aspect_ratio) = ( MultiaspectImage.calculate_new_size_by_pixel_area( - original_width, - original_height, + original_aspect_ratio, original_megapixel_resolution, ) ) @@ -391,7 +393,7 @@ def is_image_too_large(image_size: tuple, resolution: float, resolution_type: st raise ValueError(f"Unknown resolution type: {resolution_type}") @staticmethod - def calculate_new_size_by_pixel_edge(W: int, H: int, resolution: float): + def calculate_new_size_by_pixel_edge(W: int, H: int, resolution: int): aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio((W, H)) if W < H: W = resolution @@ -410,9 +412,9 @@ def calculate_new_size_by_pixel_edge(W: int, H: int, resolution: float): return int(W), int(H), new_aspect_ratio @staticmethod - def calculate_new_size_by_pixel_area(W: int, H: int, megapixels: float): - # Calculate initial dimensions based on aspect ratio and target megapixels - aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio((W, H)) + def calculate_new_size_by_pixel_area(aspect_ratio: float, megapixels: float): + if type(aspect_ratio) != float: + raise ValueError(f"Aspect ratio must be a float, not {type(aspect_ratio)}") total_pixels = max(megapixels * 1e6, 1e6) W_initial = int(round((total_pixels * aspect_ratio) ** 0.5)) H_initial = int(round((total_pixels / aspect_ratio) ** 0.5)) diff --git a/tests/test_image.py b/tests/test_image.py index 202a69e4..7268df7c 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -29,6 +29,98 @@ def setUp(self): self.mock_image_data = image_bytes.getvalue() self.data_backend.read.return_value = self.mock_image_data + def test_aspect_ratio_calculation(self): + """ + Test that the aspect ratio calculation returns expected results. + """ + self.assertEqual( + MultiaspectImage.calculate_image_aspect_ratio((1920, 1080)), 1.78 + ) + self.assertEqual( + MultiaspectImage.calculate_image_aspect_ratio((1080, 1920)), 0.56 + ) + + def test_image_resize(self): + """ + Test that images are resized to the expected dimensions. + """ + # Create a mock image + original_image = Image.new("RGB", (1920, 1080)) + + # Define target resolutions and expected output sizes + tests = [ + (1024, "pixel", (1824, 1024)), + (1.0, "area", (1344, 768)), # Assuming target is 1 megapixel + ] + + for resolution, resolution_type, expected_size in tests: + resized_image, _, _ = MultiaspectImage.prepare_image( + resolution=resolution, + image=original_image, + resolution_type=resolution_type, + id="test", + ) + + # Verify the size of the resized image + self.assertEqual(resized_image.size, expected_size) + + def test_image_size_consistency(self): + """ + Test that `prepare_image` returns consistent size for images with similar aspect ratios. + """ + # Generate random input aspect ratios and resolutions: + input_aspect_ratios = [random.uniform(0.5, 2.0) for _ in range(10)] + # Sizes should follow the list of resolutions, with between 2-4 images in each aspect + input_sizes = [] + for aspect_ratio in input_aspect_ratios: + count = 0 + for resolution in range(5, 50, 5): + count += 1 + width = resolution * 100 + height = int(width / aspect_ratio) + input_sizes.append((width, height)) + + # Sort into bucket dictionary using MultiaspectImage.calculate_image_aspect_ratio + input_sizes_dict = {} + for size in input_sizes: + aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio(size) + if aspect_ratio not in input_sizes_dict: + input_sizes_dict[aspect_ratio] = [] + input_sizes_dict[aspect_ratio].append(size) + + resolutions = range( + 5, 20, 5 + ) # Using a simplified resolution from the logs for the test + for aspect_ratio in set(input_sizes_dict.keys()): + for resolution in resolutions: + resolution = resolution / 10 # Convert to megapixels + output_sizes = [] + for size in input_sizes_dict[aspect_ratio]: + should_use_real_image = random.choice([True, False]) + image = ( + Image.new("RGB", size) if should_use_real_image else None + ) # Creating a dummy PIL image with the given size + image_metadata = ( + None if should_use_real_image else {"original_size": size} + ) + function_result, _, _ = MultiaspectImage.prepare_image( + image=image, + image_metadata=image_metadata, + resolution=resolution, + resolution_type="area", + ) + if hasattr(function_result, "size"): + output_size = function_result.size + else: + output_size = function_result + output_sizes.append(output_size) + + # Check if all output sizes are the same, indicating consistent resizing/cropping + self.assertTrue( + all(size == output_sizes[0] for size in output_sizes), + f"Output sizes are not consistent for {resolution} MP", + ) + def test_crop_corner(self): cropped_image, _ = MultiaspectImage._crop_corner( self.test_image, self.resolution, self.resolution @@ -89,7 +181,7 @@ def test_calculate_new_size_by_pixel_area(self): # Calculate new size new_width, new_height, new_aspect_ratio = ( MultiaspectImage.calculate_new_size_by_pixel_area( - original_width, original_height, mp + original_aspect_ratio, mp ) ) @@ -123,7 +215,7 @@ def test_calculate_new_size_by_pixel_area_uniformity(self): for W, H, megapixels in test_cases: W_final, H_final, _ = MultiaspectImage.calculate_new_size_by_pixel_area( - W, H, megapixels + MultiaspectImage.calculate_image_aspect_ratio((W, H)), megapixels ) self.assertEqual( (W_final, H_final), expected_size, f"Failed for original size {W}x{H}" @@ -150,7 +242,7 @@ def test_calculate_new_size_by_pixel_area_squares(self): for W, H, megapixels in test_cases: W_final, H_final, _ = MultiaspectImage.calculate_new_size_by_pixel_area( - W, H, megapixels + MultiaspectImage.calculate_image_aspect_ratio((W, H)), megapixels ) self.assertEqual( (W_final, H_final), expected_size, f"Failed for original size {W}x{H}" From 87335882fb6c14588c3efcafeed305df2d022238 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 12 Apr 2024 09:04:00 -0600 Subject: [PATCH 04/18] add error check for production issue --- helpers/caching/vae.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index cf9cb81a..c33923a4 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -590,12 +590,19 @@ def _process_images_in_batch( logger.debug(f"we have {qlen} images to process.") # First Loop: Preparation and Filtering + first_aspect_ratio = None for _ in range(qlen): if image_paths: # retrieve image data from Generator, image_data: filepath = image_paths.pop() image = image_data.pop() aspect_bucket = MultiaspectImage.calculate_image_aspect_ratio(image) + if first_aspect_ratio is None: + first_aspect_ratio = aspect_bucket + elif aspect_bucket != first_aspect_ratio: + raise ValueError( + f"Image {filepath} has a different aspect ratio ({aspect_bucket}) than the first image in the batch ({first_aspect_ratio})." + ) else: filepath, image, aspect_bucket = self.process_queue.get() if self.minimum_image_size is not None: From 27efcc0f0b1e76a2d5f619fc7e171d914edd1529 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 12 Apr 2024 09:36:37 -0600 Subject: [PATCH 05/18] use new aspect for parquet --- helpers/metadata/backends/parquet.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index 5d9f6670..642b8422 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -362,13 +362,9 @@ def _process_for_bucket( id=self.data_backend.id, ) ) - if aspect_ratio != new_aspect_ratio: - raise ValueError( - f"Aspect ratio mismatch: {aspect_ratio} vs {new_aspect_ratio} for sample: {image_path_str}." - ) image_metadata["crop_coordinates"] = crop_coordinates image_metadata["target_size"] = target_size - image_metadata["aspect_ratio"] = aspect_ratio + image_metadata["aspect_ratio"] = new_aspect_ratio luminance_column = self.parquet_config.get("luminance_column", None) if luminance_column: image_metadata["luminance"] = database_image_metadata[ From f1fbdb6a3fba9c9fadd7e926456707010e66b4ef Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 12 Apr 2024 10:05:05 -0600 Subject: [PATCH 06/18] remove test since it is going on image size and not target size. instead, test after processing. --- helpers/caching/vae.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index c33923a4..162351b9 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -590,19 +590,12 @@ def _process_images_in_batch( logger.debug(f"we have {qlen} images to process.") # First Loop: Preparation and Filtering - first_aspect_ratio = None for _ in range(qlen): if image_paths: # retrieve image data from Generator, image_data: filepath = image_paths.pop() image = image_data.pop() aspect_bucket = MultiaspectImage.calculate_image_aspect_ratio(image) - if first_aspect_ratio is None: - first_aspect_ratio = aspect_bucket - elif aspect_bucket != first_aspect_ratio: - raise ValueError( - f"Image {filepath} has a different aspect ratio ({aspect_bucket}) than the first image in the batch ({first_aspect_ratio})." - ) else: filepath, image, aspect_bucket = self.process_queue.get() if self.minimum_image_size is not None: @@ -629,6 +622,7 @@ def _process_images_in_batch( ) for data in initial_data ] + first_aspect_ratio = None for future in futures: try: result = ( @@ -636,6 +630,12 @@ def _process_images_in_batch( ) # Returns (image, crop_coordinates, new_aspect_ratio) if result: # Ensure result is not None or invalid processed_images.append(result) + if first_aspect_ratio is None: + first_aspect_ratio = result[2] + elif aspect_bucket != first_aspect_ratio: + raise ValueError( + f"Image {filepath} has a different aspect ratio ({aspect_bucket}) than the first image in the batch ({first_aspect_ratio})." + ) except Exception as e: self.debug_log( f"Error processing image in pool: {e}, traceback: {traceback.format_exc()}" From 9de1e7ae5273c443890b80353e6783542ad77a79 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 12 Apr 2024 10:09:02 -0600 Subject: [PATCH 07/18] revise logging for bucket size enforcement --- helpers/metadata/backends/base.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/helpers/metadata/backends/base.py b/helpers/metadata/backends/base.py index f53fd339..3f49ad2b 100644 --- a/helpers/metadata/backends/base.py +++ b/helpers/metadata/backends/base.py @@ -387,9 +387,20 @@ def _enforce_min_bucket_size(self): """ Remove buckets that have fewer samples than batch_size and enforce minimum image size constraints. """ - for bucket in list( - self.aspect_ratio_bucket_indices.keys() + logger.info( + f"Enforcing minimum image size of {self.minimum_image_size}." + " This could take a while for very-large datasets." + ) + for bucket in tqdm( + list(self.aspect_ratio_bucket_indices.keys()), + leave=False, + desc="Enforcing minimum bucket size", ): # Safe iteration over keys + # If the bucket is empty, remove it + if not self.aspect_ratio_bucket_indices[bucket]: + logger.debug(f"Bucket {bucket} is empty. Removing.") + del self.aspect_ratio_bucket_indices[bucket] + continue # # Prune the smaller buckets so that we don't enforce resolution constraints on them unnecessarily. # self._prune_small_buckets(bucket) self._enforce_resolution_constraints(bucket) @@ -412,10 +423,6 @@ def _enforce_resolution_constraints(self, bucket): Enforce resolution constraints on images in a bucket. """ if self.minimum_image_size is not None: - logger.info( - f"Enforcing minimum image size of {self.minimum_image_size}." - " This could take a while for very-large datasets." - ) if bucket not in self.aspect_ratio_bucket_indices: logger.debug( f"Bucket {bucket} was already removed due to insufficient samples." From 449bb9e1de91ef1d1619fbc23b8b32be67e61ef4 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 12 Apr 2024 10:14:53 -0600 Subject: [PATCH 08/18] reinstate bucket pruning --- helpers/metadata/backends/base.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/helpers/metadata/backends/base.py b/helpers/metadata/backends/base.py index 3f49ad2b..6b0437ed 100644 --- a/helpers/metadata/backends/base.py +++ b/helpers/metadata/backends/base.py @@ -396,16 +396,11 @@ def _enforce_min_bucket_size(self): leave=False, desc="Enforcing minimum bucket size", ): # Safe iteration over keys - # If the bucket is empty, remove it - if not self.aspect_ratio_bucket_indices[bucket]: - logger.debug(f"Bucket {bucket} is empty. Removing.") - del self.aspect_ratio_bucket_indices[bucket] - continue - # # Prune the smaller buckets so that we don't enforce resolution constraints on them unnecessarily. - # self._prune_small_buckets(bucket) + # Prune the smaller buckets so that we don't enforce resolution constraints on them unnecessarily. + self._prune_small_buckets(bucket) self._enforce_resolution_constraints(bucket) - # # We do this twice in case there were any new contenders for being too small. - # self._prune_small_buckets(bucket) + # We do this twice in case there were any new contenders for being too small. + self._prune_small_buckets(bucket) def _prune_small_buckets(self, bucket): """ From da6404f6b3d6a5f2da01751256fdd567e3952b21 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 12 Apr 2024 10:17:28 -0600 Subject: [PATCH 09/18] no printing of zero-sized buckets --- helpers/data_backend/factory.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index 72afa7dd..10b4ede2 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -133,6 +133,8 @@ def print_bucket_info(metadata_backend): # Print each bucket's information for bucket in metadata_backend.aspect_ratio_bucket_indices: image_count = len(metadata_backend.aspect_ratio_bucket_indices[bucket]) + if image_count == 0: + continue print(f"{rank_info()} | {bucket:<10} | {image_count:<12}") From 26c5e1d89c35fc2b6ce566e616c5f56ac0ad8fb0 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 12 Apr 2024 18:52:09 -0600 Subject: [PATCH 10/18] check aspect ratio and stuff the batches up the ass --- helpers/caching/vae.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index 162351b9..e29c9c73 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -644,11 +644,17 @@ def _process_images_in_batch( # Second Loop: Final Processing is_final_sample = False output_values = [] + first_aspect_ratio = None for idx, (image, crop_coordinates, new_aspect_ratio) in enumerate( processed_images ): if idx == len(processed_images) - 1: is_final_sample = True + if first_aspect_ratio is None: + first_aspect_ratio = new_aspect_ratio + elif new_aspect_ratio != first_aspect_ratio: + is_final_sample = True + first_aspect_ratio = new_aspect_ratio filepath, _, aspect_bucket = initial_data[idx] filepaths.append(filepath) From 2a35b252139b65ba68c6e05c3c831066e2f9fa3c Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 13 Apr 2024 08:21:18 -0600 Subject: [PATCH 11/18] Log the aspect ratios for each input sample --- helpers/training/collate.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/helpers/training/collate.py b/helpers/training/collate.py index e206bb7a..167bd12a 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -212,11 +212,13 @@ def gather_conditional_size_features(examples, latents, weight_dtype): return torch.stack(batch_time_ids_list, dim=0) -def check_latent_shapes(latents, filepaths): +def check_latent_shapes(latents, filepaths, batch): reference_shape = latents[0].shape for idx, latent in enumerate(latents): if latent.shape != reference_shape: - print(f"Latent shape mismatch for file: {filepaths[idx]}") + print( + f"Latent shape mismatch for file: {filepaths[idx]}, aspect ratios: {[example['aspect_ratio'] for example in batch]}" + ) def collate_fn(batch): @@ -252,7 +254,7 @@ def collate_fn(batch): debug_log("Compute latents") latent_batch = compute_latents(filepaths, data_backend_id) debug_log("Check latents") - check_latent_shapes(latent_batch, filepaths) + check_latent_shapes(latent_batch, filepaths, batch) # Compute embeddings and handle dropped conditionings debug_log("Extract captions") From e6fac9b12b23ee6d3b84ab6fbcfa59738219e5ab Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 13 Apr 2024 09:16:00 -0600 Subject: [PATCH 12/18] move check to the check function, whoops --- helpers/training/collate.py | 49 +++++++++++++++++-------------------- tests/test_collate.py | 3 +++ 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/helpers/training/collate.py b/helpers/training/collate.py index 167bd12a..f07e7920 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -101,27 +101,7 @@ def compute_latents(filepaths, data_backend_id: str): logger.error(f"(id={data_backend_id}) Error while computing latents: {e}") raise - # Validate shapes - test_shape = latents[0].shape - for idx, latent in enumerate(latents): - # Are there any inf or nan positions? - if torch.isnan(latent).any() or torch.isinf(latent).any(): - # get the data_backend - data_backend = StateTracker.get_data_backend(data_backend_id) - # remove the object - data_backend["vaecache"].data_backend.delete(filepaths[idx]) - raise ValueError( - f"(id={data_backend_id}) Deleted cache file {filepaths[idx]}: contains NaN or Inf values: {latent}" - ) - if latent.shape != test_shape: - raise ValueError( - f"(id={data_backend_id}) File {filepaths[idx]} latent shape mismatch: {latent.shape} != {test_shape}" - ) - - debug_log(f" -> stacking {len(latents)} latents") - return torch.stack( - [latent.to(StateTracker.get_accelerator().device) for latent in latents] - ) + return latents def compute_single_embedding(caption, text_embed_cache, is_sdxl): @@ -212,14 +192,29 @@ def gather_conditional_size_features(examples, latents, weight_dtype): return torch.stack(batch_time_ids_list, dim=0) -def check_latent_shapes(latents, filepaths, batch): - reference_shape = latents[0].shape +def check_latent_shapes(latents, filepaths, data_backend_id): + # Validate shapes + test_shape = latents[0].shape for idx, latent in enumerate(latents): - if latent.shape != reference_shape: - print( - f"Latent shape mismatch for file: {filepaths[idx]}, aspect ratios: {[example['aspect_ratio'] for example in batch]}" + # Are there any inf or nan positions? + if torch.isnan(latent).any() or torch.isinf(latent).any(): + # get the data_backend + data_backend = StateTracker.get_data_backend(data_backend_id) + # remove the object + data_backend["vaecache"].data_backend.delete(filepaths[idx]) + raise ValueError( + f"(id={data_backend_id}) Deleted cache file {filepaths[idx]}: contains NaN or Inf values: {latent}" + ) + if latent.shape != test_shape: + raise ValueError( + f"(id={data_backend_id}) File {filepaths[idx]} latent shape mismatch: {latent.shape} != {test_shape}" ) + debug_log(f" -> stacking {len(latents)} latents") + return torch.stack( + [latent.to(StateTracker.get_accelerator().device) for latent in latents] + ) + def collate_fn(batch): if len(batch) != 1: @@ -254,7 +249,7 @@ def collate_fn(batch): debug_log("Compute latents") latent_batch = compute_latents(filepaths, data_backend_id) debug_log("Check latents") - check_latent_shapes(latent_batch, filepaths, batch) + latent_batch = check_latent_shapes(latent_batch, filepaths, data_backend_id) # Compute embeddings and handle dropped conditionings debug_log("Extract captions") diff --git a/tests/test_collate.py b/tests/test_collate.py index 9d4924c6..fbbc19f0 100644 --- a/tests/test_collate.py +++ b/tests/test_collate.py @@ -26,6 +26,8 @@ def setUp(self): ] # Mock StateTracker.get_args() to return a mock object with required attributes StateTracker.set_args(MagicMock(caption_dropout_probability=0.5)) + fake_accelerator = MagicMock(device="cpu") + StateTracker.set_accelerator(fake_accelerator) @patch("helpers.training.collate.compute_latents") @patch("helpers.training.collate.compute_prompt_embeddings") @@ -42,6 +44,7 @@ def test_collate_fn(self, mock_gather, mock_compute_embeds, mock_compute_latents mock_gather.return_value = torch.tensor( [[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]] ) + mock_compute_latents.to = MagicMock(return_value=mock_compute_latents) # Call collate_fn with a mock batch with patch("helpers.training.state_tracker.StateTracker.get_data_backend"): From 18b0aa05db7b1fe1226a96c509a7db1a2ec81641 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 13 Apr 2024 12:12:46 -0600 Subject: [PATCH 13/18] fix test, enhance check --- helpers/training/collate.py | 10 ++++++++-- tests/test_collate.py | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/helpers/training/collate.py b/helpers/training/collate.py index f07e7920..6b168cdc 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -192,9 +192,15 @@ def gather_conditional_size_features(examples, latents, weight_dtype): return torch.stack(batch_time_ids_list, dim=0) -def check_latent_shapes(latents, filepaths, data_backend_id): +def check_latent_shapes(latents, filepaths, data_backend_id, batch): # Validate shapes test_shape = latents[0].shape + # Check all "aspect_ratio" values and raise error if any differ, with the two differing values: + for example in batch[0]: + if example["aspect_ratio"] != batch[0][0]["aspect_ratio"]: + raise ValueError( + f"Aspect ratio mismatch: {example['aspect_ratio']} != {batch[0][0]['aspect_ratio']}" + ) for idx, latent in enumerate(latents): # Are there any inf or nan positions? if torch.isnan(latent).any() or torch.isinf(latent).any(): @@ -249,7 +255,7 @@ def collate_fn(batch): debug_log("Compute latents") latent_batch = compute_latents(filepaths, data_backend_id) debug_log("Check latents") - latent_batch = check_latent_shapes(latent_batch, filepaths, data_backend_id) + latent_batch = check_latent_shapes(latent_batch, filepaths, data_backend_id, batch) # Compute embeddings and handle dropped conditionings debug_log("Extract captions") diff --git a/tests/test_collate.py b/tests/test_collate.py index fbbc19f0..1474d74f 100644 --- a/tests/test_collate.py +++ b/tests/test_collate.py @@ -21,6 +21,7 @@ def setUp(self): "image_data": MagicMock(), "crop_coordinates": [0, 0, 100, 100], "data_backend_id": "foo", + "aspect_ratio": 1.0, }, # Add more examples as needed ] From 784b70aa8148f64fcbed9a86ce768e51a0dd31ae Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 13 Apr 2024 13:01:01 -0600 Subject: [PATCH 14/18] debugging aspect selection --- helpers/multiaspect/dataset.py | 7 +++++++ helpers/multiaspect/sampler.py | 6 ++++++ 2 files changed, 13 insertions(+) diff --git a/helpers/multiaspect/dataset.py b/helpers/multiaspect/dataset.py index 7918d048..938470ae 100644 --- a/helpers/multiaspect/dataset.py +++ b/helpers/multiaspect/dataset.py @@ -29,8 +29,15 @@ def __len__(self): def __getitem__(self, image_tuple): output_data = [] + first_aspect_ratio = None for sample in image_tuple: image_metadata = sample + if first_aspect_ratio is None: + first_aspect_ratio = image_metadata["aspect_ratio"] + elif first_aspect_ratio != image_metadata["aspect_ratio"]: + raise ValueError( + f"Aspect ratios must be the same for all images in a batch. Expected: {first_aspect_ratio}, got: {image_metadata['aspect_ratio']}" + ) if ( image_metadata["original_size"] is None or image_metadata["target_size"] is None diff --git a/helpers/multiaspect/sampler.py b/helpers/multiaspect/sampler.py index e9dc3fe3..626eb953 100644 --- a/helpers/multiaspect/sampler.py +++ b/helpers/multiaspect/sampler.py @@ -341,11 +341,17 @@ def __iter__(self): to_yield = self._validate_and_yield_images_from_samples( samples, self.buckets[self.current_bucket] ) + self.debug_log( + f"Building batch with {len(self.batch_accumulator)} samples." + ) if len(self.batch_accumulator) < self.batch_size: remaining_entries_needed = self.batch_size - len( self.batch_accumulator ) # Now we'll add only remaining_entries_needed amount to the accumulator: + self.debug_log( + f"Current bucket: {self.current_bucket}. Adding samples with aspect ratios: {[i['aspect_ratio'] for i in to_yield[:remaining_entries_needed]]}" + ) self.batch_accumulator.extend(to_yield[:remaining_entries_needed]) # If the batch is full, yield it if len(self.batch_accumulator) >= self.batch_size: From 00f64998c75cfdac1fddca538c7b019b5fb35575 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 13 Apr 2024 16:53:38 -0600 Subject: [PATCH 15/18] parquet: bucket images correctly --- helpers/metadata/backends/parquet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index 642b8422..85eec56f 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -377,9 +377,9 @@ def _process_for_bucket( ) # Create a new bucket if it doesn't exist - if str(aspect_ratio) not in aspect_ratio_bucket_indices: - aspect_ratio_bucket_indices[str(aspect_ratio)] = [] - aspect_ratio_bucket_indices[str(aspect_ratio)].append(image_path_str) + if str(new_aspect_ratio) not in aspect_ratio_bucket_indices: + aspect_ratio_bucket_indices[str(new_aspect_ratio)] = [] + aspect_ratio_bucket_indices[str(new_aspect_ratio)].append(image_path_str) # Instead of directly updating, just fill the provided dictionary if metadata_updates is not None: metadata_updates[image_path_str] = image_metadata From a122b0a4518217800434677b6a4c976ea027ef38 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 13 Apr 2024 16:57:57 -0600 Subject: [PATCH 16/18] json: optimise calculation of aspect bucket by removing re-calculation --- helpers/metadata/backends/json.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/helpers/metadata/backends/json.py b/helpers/metadata/backends/json.py index 12b9d9df..7a9e50c8 100644 --- a/helpers/metadata/backends/json.py +++ b/helpers/metadata/backends/json.py @@ -229,19 +229,16 @@ def _process_for_bucket( image_metadata["crop_coordinates"] = crop_coordinates image_metadata["target_size"] = image.size # Round to avoid excessive unique buckets - aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( - image, aspect_ratio_rounding - ) - image_metadata["aspect_ratio"] = aspect_ratio + image_metadata["aspect_ratio"] = new_aspect_ratio image_metadata["luminance"] = calculate_luminance(image) logger.debug( - f"Image {image_path_str} has aspect ratio {aspect_ratio} and size {image.size}." + f"Image {image_path_str} has aspect ratio {new_aspect_ratio} and size {image.size}." ) # Create a new bucket if it doesn't exist - if str(aspect_ratio) not in aspect_ratio_bucket_indices: - aspect_ratio_bucket_indices[str(aspect_ratio)] = [] - aspect_ratio_bucket_indices[str(aspect_ratio)].append(image_path_str) + if str(new_aspect_ratio) not in aspect_ratio_bucket_indices: + aspect_ratio_bucket_indices[str(new_aspect_ratio)] = [] + aspect_ratio_bucket_indices[str(new_aspect_ratio)].append(image_path_str) # Instead of directly updating, just fill the provided dictionary if metadata_updates is not None: metadata_updates[image_path_str] = image_metadata From 5771e72435187fa5ca84ae61dd2ba81717383d2d Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 14 Apr 2024 16:39:08 -0600 Subject: [PATCH 17/18] llava: add 1.6 support, though it is currently bugged out in transformers somehow --- toolkit/captioning/caption_with_llava.py | 77 ++++++++++++++++++------ 1 file changed, 59 insertions(+), 18 deletions(-) diff --git a/toolkit/captioning/caption_with_llava.py b/toolkit/captioning/caption_with_llava.py index 8b8583b6..2b641b2d 100644 --- a/toolkit/captioning/caption_with_llava.py +++ b/toolkit/captioning/caption_with_llava.py @@ -2,8 +2,15 @@ from PIL import Image from tqdm import tqdm import requests, io -from transformers import BitsAndBytesConfig +try: + from transformers import BitsAndBytesConfig +except: + pass +try: + from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration +except: + pass from transformers import AutoProcessor, LlavaForConditionalGeneration logger = logging.getLogger("Captioner") @@ -164,31 +171,65 @@ def process_parquet_dataset(args, model, processor): def load_llava_model( model_path: str = "llava-hf/llava-1.5-7b-hf", precision: str = "fp4" ): - bnb_config = BitsAndBytesConfig() - if precision == "fp4": - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16, + try: + bnb_config = BitsAndBytesConfig() + if precision == "fp4": + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + elif precision == "fp8": + bnb_config = BitsAndBytesConfig( + load_in_8bit=True, + ) + torch_dtype = None + except: + bnb_config = None + torch_dtype = torch.float16 + if "1.6" in model_path: + model = LlavaNextForConditionalGeneration.from_pretrained( + model_path, + quantization_config=bnb_config, + torch_dtype=torch_dtype, + device_map="auto", ) - elif precision == "fp8": - bnb_config = BitsAndBytesConfig( - load_in_8bit=True, + else: + model = LlavaForConditionalGeneration.from_pretrained( + model_path, + quantization_config=bnb_config, + torch_dtype=torch_dtype, + device_map="auto", ) - model = LlavaForConditionalGeneration.from_pretrained( - model_path, quantization_config=bnb_config, device_map="auto" - ) - processor = AutoProcessor.from_pretrained(model_path) + if "1.6" in model_path: + autoprocessor_cls = LlavaNextProcessor + else: + autoprocessor_cls = AutoProcessor + processor = autoprocessor_cls.from_pretrained(model_path) return model, processor # Function to evaluate the model def eval_model(args, image_file, model, processor): - images = [image_file] - prompt = f"\nUSER: {args.query_str}\nASSISTANT:" - inputs = processor(text=prompt, images=images, return_tensors="pt").to("cuda") + + if type(processor) == LlavaNextProcessor: + logging.info("Using LLaVA 1.6+ model.") + prompt = f"<|im_start|>system\nAnswer the question carefully, without speculation.<|im_end|><|im_start|>user\n\n{args.query_str}<|im_end|><|im_start|>assistant\n" + inputs = processor(text=prompt, images=image_file, return_tensors="pt").to( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) + else: + prompt = f"\nUSER: {args.query_str}\nASSISTANT:" + images = [image_file] + inputs = processor(text=prompt, images=images, return_tensors="pt").to( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) with torch.inference_mode(): generate_ids = model.generate( **inputs, From 7ab0fba7f25d166b133f7c6bd8650f95db9836f1 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 14 Apr 2024 17:27:12 -0600 Subject: [PATCH 18/18] bump max grad norm to 2 by default --- helpers/arguments.py | 8 +++++++- kohya_config.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/helpers/arguments.py b/helpers/arguments.py index d2d98de4..cc2ff36e 100644 --- a/helpers/arguments.py +++ b/helpers/arguments.py @@ -728,7 +728,13 @@ def parse_args(input_args=None): help="Whether or not to use stochastic bf16 in Adam. Currently the only supported optimizer.", ) parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + "--max_grad_norm", + default=2.0, + type=float, + help=( + "Clipping the max gradient norm can help prevent exploding gradients, but" + " may also harm training by introducing artifacts or making it hard to train artifacts away." + ), ) parser.add_argument( "--push_to_hub", diff --git a/kohya_config.py b/kohya_config.py index 886411e3..c1564dbe 100644 --- a/kohya_config.py +++ b/kohya_config.py @@ -197,7 +197,7 @@ "scale_weight_norms": { "parameter": "max_grad_norm", "warn_if_value": 0, - "warning": "In SimpleTuner, max_grad_norm is set to 1 by default. Please change this if you see issues.", + "warning": "In SimpleTuner, max_grad_norm is set to 2 by default. Please change this if you see issues.", }, "sdxl": {"script_name": "train_sdxl.py"}, "sdxl_cache_text_encoder_outputs": {