bugfix: bucket search for unseen images should prepend the instance data root so that the images can actually be loaded from disk | image loading should use CV2 as primary method with PIL as fallback | batch prefetch should be correctly destroyed/shutdown upon error or ctrl+c | VAE embed inconsistency fixed by cloning latent before write #552

Merged: 20 commits, Jul 2, 2024
4 changes: 4 additions & 0 deletions INSTALL.md
@@ -100,3 +100,7 @@ bash train_sdxl.sh > /path/to/training-$(date +%s).log 2>&1
```

> ⚠️ At this point, the commands will work, but further configuration is required. See [the tutorial](/TUTORIAL.md) for more information.

### Run unit tests

To run the unit tests and confirm the installation completed successfully, execute `poetry run python -m unittest discover tests/`.
17 changes: 8 additions & 9 deletions helpers/caching/vae.py
@@ -601,16 +601,13 @@ def _write_latents_in_batch(self, input_latents: list = None):
else:
output_file, filepath, latent_vector = self.write_queue.get()
file_extension = os.path.splitext(output_file)[1]
if (
file_extension == ".png"
or file_extension == ".jpg"
or file_extension == ".jpeg"
):
if file_extension != ".pt":
raise ValueError(
f"Cannot write a latent embedding to an image path, {output_file}"
)
filepaths.append(output_file)
latents.append(latent_vector)
# pytorch will hold onto all of the tensors in the list if we do not use clone()
latents.append(latent_vector.clone())

self.data_backend.write_batch(filepaths, latents)
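
For context on the `clone()` change: a queued latent that is a view of a larger tensor shares that tensor's storage, so the source tensor stays alive until the view is released, and any in-place mutation after queueing changes what eventually gets written. A minimal sketch (not part of this diff) of the aliasing behavior:

```python
import torch

# Hypothetical illustration: a view of a batch tensor aliases its storage.
batch = torch.zeros(4, 4)
view = batch[0]              # shares storage with `batch`
snapshot = batch[0].clone()  # independent copy of just this row

batch.add_(1)                # in-place update after "queueing"
print(view[0].item())        # 1.0 -- the view reflects the later mutation
print(snapshot[0].item())    # 0.0 -- the clone preserved the queued value
```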

@@ -1106,9 +1103,11 @@ def process_buckets(self):
futures.append(future_to_write)

futures = self._process_futures(futures, executor)
import json

logger.info(f"Bucket {bucket} caching results: {statistics}")
log_msg = (
f"(id={self.id}) Bucket {bucket} caching results: {statistics}"
)
logger.debug(log_msg)
tqdm.write(log_msg)
self.debug_log(
"Completed process_buckets, all futures have been returned."
)
6 changes: 3 additions & 3 deletions helpers/data_backend/aws.py
@@ -10,7 +10,7 @@
import concurrent.futures
from botocore.config import Config
from helpers.data_backend.base import BaseDataBackend
from PIL import Image
from helpers.image_manipulation.load import load_image
from io import BytesIO

loggers_to_silence = [
@@ -247,7 +247,7 @@ def list_files(self, str_pattern: str, instance_data_root: str = None):
return results

def read_image(self, s3_key):
return Image.open(BytesIO(self.read(s3_key)))
return load_image(BytesIO(self.read(s3_key)))

def read_image_batch(self, s3_keys: list, delete_problematic_images: bool = False):
"""
@@ -265,7 +265,7 @@ def read_image_batch(self, s3_keys: list, delete_problematic_images: bool = False):
available_keys = []
for s3_key, data in zip(s3_keys, batch):
try:
image_data = Image.open(BytesIO(data))
image_data = load_image(BytesIO(data))
if image_data is None:
logger.warning(f"Unable to load image '{s3_key}', skipping.")
continue
10 changes: 7 additions & 3 deletions helpers/data_backend/local.py
@@ -1,9 +1,9 @@
from helpers.data_backend.base import BaseDataBackend
from helpers.image_manipulation.load import load_image
from pathlib import Path
from io import BytesIO
import os, logging, torch, gzip
from typing import Any
from PIL import Image

logger = logging.getLogger("LocalDataBackend")
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
@@ -114,7 +114,7 @@ def read_image(self, filepath: str, delete_problematic_images: bool = False):
# Remove embedded null byte:
filepath = filepath.replace("\x00", "")
try:
image = Image.open(filepath)
image = load_image(filepath)
return image
except Exception as e:
import traceback
@@ -183,7 +183,7 @@ def torch_load(self, filename):
logger.error(
f"Failed to decompress torch file, falling back to passthrough: {e}"
)
loaded_tensor = torch.load(stored_tensor, map_location="cpu")
try:
loaded_tensor = torch.load(stored_tensor, map_location="cpu")
except Exception as e:
logger.error(f"Failed to load torch file '{filename}': {e}")
raise e
return loaded_tensor

def torch_save(self, data, original_location):
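
The surrounding `torch_load` (only partially visible in this hunk) appears to try gzip decompression first and fall back to the raw stream before handing the bytes to `torch.load`; a self-contained sketch of that pattern, under that assumption:

```python
import gzip
import io
import logging

import torch

logger = logging.getLogger("LocalDataBackend")


def load_tensor(raw: bytes, filename: str = "<buffer>") -> torch.Tensor:
    # Try gzip first; fall back to treating the payload as uncompressed.
    try:
        payload = gzip.decompress(raw)
    except (OSError, gzip.BadGzipFile) as e:
        logger.error(f"Failed to decompress torch file, falling back to passthrough: {e}")
        payload = raw
    stored_tensor = io.BytesIO(payload)
    try:
        return torch.load(stored_tensor, map_location="cpu")
    except Exception as e:
        logger.error(f"Failed to load torch file '{filename}': {e}")
        raise


# Round trip with a gzip-compressed tensor:
buf = io.BytesIO()
torch.save(torch.ones(2, 2), buf)
print(load_tensor(gzip.compress(buf.getvalue())))
```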
82 changes: 82 additions & 0 deletions helpers/image_manipulation/load.py
@@ -0,0 +1,82 @@
import logging

from io import BytesIO
from typing import Union, IO, Any

import cv2
import numpy as np

from PIL import Image, PngImagePlugin


logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)


def decode_image_with_opencv(nparr: np.ndarray) -> Union[Image.Image, None]:
    img_cv = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img_cv is None:
        return None
    img_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
    # Convert to RGB only if the decoded array is single-channel.
    if len(img_cv.shape) == 2 or img_cv.shape[2] == 1:
        img_cv = cv2.cvtColor(img_cv, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(img_cv)


def decode_image_with_pil(img_data: bytes) -> Image.Image:
try:
if isinstance(img_data, bytes):
img_pil = Image.open(BytesIO(img_data))
else:
img_pil = Image.open(img_data)

if img_pil.mode not in ["RGB", "RGBA"] and "transparency" in img_pil.info:
img_pil = img_pil.convert("RGBA")

# For transparent images, add a white background as this is correct
# most of the time.
if img_pil.mode == 'RGBA':
canvas = Image.new("RGBA", img_pil.size, (255, 255, 255))
canvas.alpha_composite(img_pil)
img_pil = canvas.convert("RGB")
else:
img_pil = img_pil.convert('RGB')
except (OSError, Image.DecompressionBombError, ValueError) as e:
logger.warning(f'Error decoding image: {e}')
raise
return img_pil


def load_image(img_data: Union[bytes, IO[Any], str]) -> Image.Image:
'''
Load an image using CV2. If that fails, fall back to PIL.

The image is returned as a PIL object.
'''
if isinstance(img_data, str):
with open(img_data, 'rb') as file:
img_data = file.read()
elif hasattr(img_data, 'read'):
# Check if it's a file-like object.
img_data = img_data.read()

# Preload the image with channels unchanged to determine whether it has
# an alpha channel. If it does, decode with PIL so a white background
# can be composited under the transparency.
nparr = np.frombuffer(img_data, np.uint8)
image_preload = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
has_alpha = False
if image_preload is not None and len(image_preload.shape) == 3 and image_preload.shape[2] == 4:
has_alpha = True
del image_preload

img = None
if not has_alpha:
img = decode_image_with_opencv(nparr)
if img is None:
img = decode_image_with_pil(img_data)
return img
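
A usage sketch for the new helper (the path is hypothetical): it accepts a filesystem path, raw bytes, or a file-like object, and always returns an RGB `PIL.Image`, with transparency composited onto a white background.

```python
from io import BytesIO

from helpers.image_manipulation.load import load_image

path = "/data/images/example.png"  # hypothetical file

img = load_image(path)                 # from a path
with open(path, "rb") as f:
    img_bytes = f.read()
img = load_image(img_bytes)            # from raw bytes (e.g. an S3 payload)
img = load_image(BytesIO(img_bytes))   # from a file-like object

print(img.mode)  # "RGB" -- alpha, if present, was composited onto white
```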
2 changes: 1 addition & 1 deletion helpers/image_manipulation/training_sample.py
@@ -691,7 +691,7 @@ def get_conditioning_image(self):
conditioning_dataset = StateTracker.get_conditioning_dataset(
data_backend_id=self.data_backend_id
)
print(f"Conditioning dataset components: {conditioning_dataset.keys()}")
# print(f"Conditioning dataset components: {conditioning_dataset.keys()}")

def cache_path(self):
metadata_backend = StateTracker.get_data_backend(self.data_backend_id)[
28 changes: 22 additions & 6 deletions helpers/metadata/backends/base.py
@@ -297,7 +297,7 @@ def compute_aspect_ratio_bucket_indices(self, ignore_existing_cache: bool = False):
logger.info(f"Image processing statistics: {aggregated_statistics}")
self.save_image_metadata()
self.save_cache(enforce_constraints=True)
logger.info("Completed aspect bucket update.")
logger.info(f"Completed aspect bucket update.")

def split_buckets_between_processes(self, gradient_accumulation_steps=1):
"""
@@ -320,6 +320,11 @@ def split_buckets_between_processes(self, gradient_accumulation_steps=1):
num_batches = len(images) // effective_batch_size
trimmed_images = images[: num_batches * effective_batch_size]
logger.debug(f"Trimmed from {len(images)} to {len(trimmed_images)}")
if len(trimmed_images) == 0:
logger.error(
f"Bucket {bucket} has no images after trimming because {len(images)} images are not enough to satisfy an effective batch size of {effective_batch_size}."
" Lower your batch size, increase repeat count, or increase data pool size."
)

with self.accelerator.split_between_processes(
trimmed_images, apply_padding=False
@@ -430,9 +435,10 @@ def _enforce_min_bucket_size(self):
): # Safe iteration over keys
# Prune the smaller buckets so that we don't enforce resolution constraints on them unnecessarily.
self._prune_small_buckets(bucket)
self._enforce_resolution_constraints(bucket)
# We do this twice in case there were any new contenders for being too small.
self._prune_small_buckets(bucket)
if self.minimum_image_size is not None:
self._enforce_resolution_constraints(bucket)
# We do this twice in case there were any new contenders for being too small.
self._prune_small_buckets(bucket)

def _prune_small_buckets(self, bucket):
"""
@@ -442,8 +448,11 @@ def _prune_small_buckets(self, bucket):
bucket in self.aspect_ratio_bucket_indices
and len(self.aspect_ratio_bucket_indices[bucket]) < self.batch_size
):
bucket_sample_count = len(self.aspect_ratio_bucket_indices[bucket])
del self.aspect_ratio_bucket_indices[bucket]
logger.warning(f"Removed bucket {bucket} due to insufficient samples.")
logger.warning(
f"Removing bucket {bucket} due to insufficient samples; your batch size may be too large for the small quantity of data (batch_size={self.batch_size} > sample_count={bucket_sample_count})."
)

def _enforce_resolution_constraints(self, bucket):
"""
@@ -456,6 +465,7 @@
)
return
images = self.aspect_ratio_bucket_indices[bucket]
total_before = len(images)
self.aspect_ratio_bucket_indices[bucket] = [
img
for img in images
Expand All @@ -464,6 +474,12 @@ def _enforce_resolution_constraints(self, bucket):
image=None,
)
]
total_after = len(self.aspect_ratio_bucket_indices[bucket])
total_lost = total_before - total_after
if total_lost > 0:
logger.info(
f"Had {total_before} samples before and {total_lost} that did not meet the minimum image size requirement ({self.minimum_image_size})."
)

def meets_resolution_requirements(
self,
@@ -502,7 +518,7 @@ def meets_resolution_requirements(
# We receive an integer megapixel value, so minimum_image_size must be converted from MP to pixels before comparing.
if self.minimum_image_size > 5:
raise ValueError(
f"--minimum_image_size was given with a value of {minimum_image_size} but resolution_type is area, which means this value is most likely too large. Please use a value less than 5."
f"--minimum_image_size was given with a value of {self.minimum_image_size} but resolution_type is area, which means this value is most likely too large. Please use a value less than 5."
)
# We need to find the square image length if crop_style = square.
minimum_image_size = self.minimum_image_size * 1_000_000
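
A worked example of the trimming that the new error message guards against (values are hypothetical, and the formula for `effective_batch_size` is assumed to be batch size × gradient accumulation steps × process count, matching the method's intent):

```python
batch_size, grad_accum, num_processes = 4, 2, 2
effective_batch_size = batch_size * grad_accum * num_processes  # 16

images = [f"img_{i}.png" for i in range(15)]  # a bucket with 15 images

num_batches = len(images) // effective_batch_size        # 15 // 16 == 0
trimmed_images = images[: num_batches * effective_batch_size]
assert trimmed_images == []  # the bucket empties and the new error fires
```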
4 changes: 2 additions & 2 deletions helpers/metadata/backends/json.py
@@ -3,10 +3,10 @@
from helpers.data_backend.base import BaseDataBackend
from helpers.metadata.backends.base import MetadataBackend
from helpers.image_manipulation.training_sample import TrainingSample
from helpers.image_manipulation.load import load_image
from pathlib import Path
import json, logging, os, traceback
from multiprocessing import Manager
from PIL import Image
from tqdm import tqdm
from multiprocessing import Process, Queue
import numpy as np
@@ -220,7 +220,7 @@ def _process_for_bucket(
statistics["skipped"]["not_found"] += 1
return aspect_ratio_bucket_indices

with Image.open(BytesIO(image_data)) as image:
with load_image(BytesIO(image_data)) as image:
if not self.meets_resolution_requirements(image=image):
if not self.delete_unwanted_images:
logger.debug(
12 changes: 8 additions & 4 deletions helpers/metadata/backends/parquet.py
@@ -136,11 +136,16 @@ def _extract_captions_to_fast_list(self):
filename = str(row[filename_column])
else:
filename = str(index)

if not identifier_includes_extension:
filename = os.path.splitext(filename)[0]

caption = row[caption_column]
if type(caption_column) == list:
caption = None
if len(caption_column) > 0:
caption = [row[c] for c in caption_column]
else:
caption = row[caption_column]

if not caption and fallback_caption_column:
caption = row[fallback_caption_column]
if not caption:
@@ -150,8 +155,7 @@
if type(caption) == bytes:
caption = caption.decode("utf-8")
elif type(caption) == list:
# selecting the first caption. see issue #476 on github
caption = caption[0]
caption = [c.strip() for c in caption if c.strip()]
if caption:
caption = caption.strip()
captions[filename] = caption
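
With this change, a caption column configured as a list now yields every cleaned caption for the row instead of silently taking the first one (the previous workaround for issue #476); the sampler changes below then draw one at random per sample. A sketch with hypothetical column names:

```python
import random

# Hypothetical parquet row and configuration:
row = {"caption_short": " a cat ", "caption_long": "a cat on a windowsill"}
caption_column = ["caption_short", "caption_long"]

if isinstance(caption_column, list):
    caption = [row[c] for c in caption_column] if caption_column else None
else:
    caption = row[caption_column]

if isinstance(caption, list):
    # Strip whitespace and drop empties, as in _extract_captions_to_fast_list.
    caption = [c.strip() for c in caption if c.strip()]

# Later, at sampling time (see sampler.py below):
prompt = random.choice(caption)  # e.g. "a cat" or "a cat on a windowsill"
```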
15 changes: 12 additions & 3 deletions helpers/multiaspect/sampler.py
@@ -160,6 +160,11 @@ def retrieve_validation_set(self, batch_size: int):
prepend_instance_prompt=self.prepend_instance_prompt,
instance_prompt=self.instance_prompt,
)
if type(validation_prompt) == list:
validation_prompt = random.choice(validation_prompt)
self.debug_log(
f"Selecting random prompt from list: {validation_prompt}"
)
results.append(
(validation_shortname, validation_prompt, training_sample.image)
)
@@ -228,7 +233,7 @@ def _get_unseen_images(self, bucket=None):
"""
if bucket and bucket in self.metadata_backend.aspect_ratio_bucket_indices:
return [
image
os.path.join(self.metadata_backend.instance_data_root, image)
for image in self.metadata_backend.aspect_ratio_bucket_indices[bucket]
if not self.metadata_backend.is_seen(image)
]
@@ -237,7 +242,7 @@
for b, images in self.metadata_backend.aspect_ratio_bucket_indices.items():
unseen_images.extend(
[
image
os.path.join(self.metadata_backend.instance_data_root, image)
for image in images
if not self.metadata_backend.is_seen(image)
]
@@ -380,7 +385,7 @@ def _validate_and_yield_images_from_samples(self, samples, bucket):
image_metadata["image_path"] = image_path

# Use the magic prompt handler to retrieve the captions.
image_metadata["instance_prompt_text"] = PromptHandler.magic_prompt(
instance_prompt = PromptHandler.magic_prompt(
sampler_backend_id=self.id,
data_backend=self.data_backend,
image_path=image_metadata["image_path"],
@@ -389,6 +394,10 @@
prepend_instance_prompt=self.prepend_instance_prompt,
instance_prompt=self.instance_prompt,
)
if type(instance_prompt) == list:
instance_prompt = random.choice(instance_prompt)
self.debug_log(f"Selecting random prompt from list: {instance_prompt}")
image_metadata["instance_prompt_text"] = instance_prompt

to_yield.append(image_metadata)
return to_yield
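
The headline bugfix in `_get_unseen_images` is visible above: aspect bucket indices store paths relative to the dataset root, so prepending `instance_data_root` makes them loadable. A tiny sketch with hypothetical paths:

```python
import os

instance_data_root = "/datasets/photos"  # hypothetical dataset root
relative = "subdir/img_001.png"          # as stored in the bucket index

full_path = os.path.join(instance_data_root, relative)
print(full_path)  # /datasets/photos/subdir/img_001.png -- now openable,
                  # where the bare relative path previously failed to load
```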