Merge pull request #858 from bghira/main
merge
bghira authored Aug 24, 2024
2 parents ea2a536 + fd78777 commit 4d52348
Showing 11 changed files with 320 additions and 79 deletions.
5 changes: 3 additions & 2 deletions config/config.env.example
@@ -81,9 +81,10 @@ export TRACKER_PROJECT_NAME="${MODEL_TYPE}-training"
export TRACKER_RUN_NAME="simpletuner-sdxl"

# Max number of steps OR epochs can be used. Not both.
export MAX_NUM_STEPS=30000
# You MUST uncomment and set one of these.
#export MAX_NUM_STEPS=30000
# Will likely overtrain, but that's fine.
export NUM_EPOCHS=0
#export NUM_EPOCHS=0

# A convenient prefix for all of your training paths.
# These may be absolute or relative paths. Here, we are using relative paths.
helpers/data_backend/csv_url_list.py (renamed from helpers/data_backend/csv.py)
@@ -1,11 +1,8 @@
import fnmatch
import io
from datetime import datetime
from urllib.request import url2pathname
import hashlib

import pandas as pd
import requests
from PIL import Image

from helpers.data_backend.base import BaseDataBackend
from helpers.image_manipulation.load import load_image
@@ -28,23 +25,21 @@ def url_to_filename(url: str) -> str:
return url.split("/")[-1]


def shorten_and_clean_filename(filename: str, no_op: bool):
if no_op:
return filename
filename = filename.replace("%20", "-").replace(" ", "-")
if len(filename) > 250:
filename = filename[:120] + "---" + filename[126:]
return filename
def str_hash(filename: str) -> str:
return str(hashlib.sha256(str(filename).encode()).hexdigest())


def html_to_file_loc(parent_directory: Path, url: str, shorten_filenames: bool) -> str:
def path_to_hashed_path(path: Path, hash_filenames: bool) -> Path:
path = Path(path).resolve()
if hash_filenames:
return path.parent.joinpath(str_hash(path.stem) + path.suffix)
return path


def html_to_file_loc(parent_directory: Path, url: str, hash_filenames: bool) -> str:
filename = url_to_filename(url)
cached_loc = str(
parent_directory.joinpath(
shorten_and_clean_filename(filename, no_op=shorten_filenames)
)
)
return cached_loc
cached_loc = path_to_hashed_path(parent_directory.joinpath(filename), hash_filenames)
return str(cached_loc.resolve())
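
A brief sketch of what the new helpers do (the directory and URL below are hypothetical, not from the commit): the URL is reduced to its filename, the stem is replaced by its SHA-256 hex digest, and the suffix is preserved.

import hashlib
from pathlib import Path

# Sketch mirroring url_to_filename() / str_hash() / path_to_hashed_path() above.
cache_dir = Path("/data/csv_cache")             # hypothetical cache directory
url = "https://example.com/images/cat.png"      # hypothetical image URL
filename = url.split("/")[-1]                   # "cat.png"
stem, suffix = Path(filename).stem, Path(filename).suffix
hashed_name = hashlib.sha256(stem.encode()).hexdigest() + suffix
print(cache_dir / hashed_name)                  # /data/csv_cache/<64-hex-digest>.png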


class CSVDataBackend(BaseDataBackend):
@@ -54,29 +49,28 @@ def __init__(
id: str,
csv_file: Path,
compress_cache: bool = False,
image_url_col: str = "url",
caption_column: str = "caption",
url_column: str = "url",
caption_column: str = "caption",
image_cache_loc: Optional[str] = None,
shorten_filenames: bool = False,
hash_filenames: bool = True,
):
self.id = id
self.type = "csv"
self.compress_cache = compress_cache
self.shorten_filenames = shorten_filenames
self.hash_filenames = hash_filenames
self.csv_file = csv_file
self.accelerator = accelerator
self.image_url_col = image_url_col
self.df = pd.read_csv(csv_file, index_col=image_url_col)
self.url_column = url_column
self.df = pd.read_csv(csv_file, index_col=url_column)
self.df = self.df.groupby(level=0).last() # deduplicate by index (image loc)
self.caption_column = caption_column
self.url_column = url_column
self.image_cache_loc = (
Path(image_cache_loc) if image_cache_loc is not None else None
)

def read(self, location, as_byteIO: bool = False):
"""Read and return the content of the file."""
already_hashed = False
if isinstance(location, Path):
location = str(location.resolve())
if location.startswith("http"):
@@ -85,11 +79,12 @@ def read(self, location, as_byteIO: bool = False):
cached_loc = html_to_file_loc(
self.image_cache_loc,
location,
shorten_filenames=self.shorten_filenames,
self.hash_filenames,
)
if os.path.exists(cached_loc):
# found cache
location = cached_loc
already_hashed = True
else:
# actually go to website
data = requests.get(location, stream=True).raw.data
@@ -99,8 +94,13 @@ def read(self, location, as_byteIO: bool = False):
data = requests.get(location, stream=True).raw.data
if not location.startswith("http"):
# read from local file
with open(location, "rb") as file:
data = file.read()
hashed_location = path_to_hashed_path(location, hash_filenames=self.hash_filenames and not already_hashed)
try:
with open(hashed_location, "rb") as file:
data = file.read()
except FileNotFoundError as e:
tqdm.write(f'ask was for file {location} bound to {hashed_location}')
raise e
if not as_byteIO:
return data
return BytesIO(data)
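
The already_hashed flag prevents hashing a path twice: when the HTTP-cache lookup hits, html_to_file_loc() has already returned a hashed name, so the later path_to_hashed_path() call must be a no-op. A minimal sketch of that rule (the path and settings are hypothetical; path_to_hashed_path is the helper defined above):

location = "/data/csv_cache/some_image.png"   # hypothetical local path
hash_filenames = True                         # backend-level setting
already_hashed = True                         # set when the HTTP cache lookup hit
hashed_location = path_to_hashed_path(
    location, hash_filenames=hash_filenames and not already_hashed
)
# already_hashed=True  -> no-op, the cached name is used as-is
# already_hashed=False -> the stem is replaced by its SHA-256 digest first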
@@ -114,9 +114,7 @@ def write(self, filepath: Union[str, Path], data: Any) -> None:
filepath = Path(filepath)
# Not a huge fan of auto-shortening filenames, as we hash things for that in other cases.
# However, this is copied in from the original Arcade-AI CSV backend implementation for compatibility.
filepath = filepath.parent.joinpath(
shorten_and_clean_filename(filepath.name, no_op=self.shorten_filenames)
)
filepath = path_to_hashed_path(filepath, self.hash_filenames)
filepath.parent.mkdir(parents=True, exist_ok=True)
with open(filepath, "wb") as file:
# Check if data is a Tensor, and if so, save it appropriately
@@ -137,6 +135,7 @@ def delete(self, filepath):
if filepath in self.df.index:
self.df.drop(filepath, inplace=True)
# self.save_state()
filepath = path_to_hashed_path(filepath, self.hash_filenames)
if os.path.exists(filepath):
logger.debug(f"Deleting file: {filepath}")
os.remove(filepath)
@@ -146,13 +145,15 @@ def delete(self, filepath):

def exists(self, filepath):
"""Check if the file exists."""
if isinstance(filepath, Path):
filepath = str(filepath.resolve())
return filepath in self.df.index or os.path.exists(filepath)
if isinstance(filepath, str) and "http" in filepath:
return filepath in self.df.index
else:
filepath = path_to_hashed_path(filepath, self.hash_filenames)
return os.path.exists(filepath)

def open_file(self, filepath, mode):
"""Open the file in the specified mode."""
return open(filepath, mode)
return open(path_to_hashed_path(filepath, self.hash_filenames), mode)

def list_files(
self, file_extensions: list = None, instance_data_dir: str = None
@@ -291,9 +292,9 @@ def torch_save(self, data, location: Union[str, Path, BytesIO]):
"""
Save a torch tensor to a file.
"""

if isinstance(location, str) or isinstance(location, Path):
if location not in self.df.index:
self.df.loc[location] = pd.Series()
location = path_to_hashed_path(location, self.hash_filenames)
location = self.open_file(location, "wb")

if self.compress_cache:
@@ -309,7 +310,7 @@ def write_batch(self, filepaths: list, data_list: list) -> None:
self.write(filepath, data)

def save_state(self):
self.df.to_csv(self.csv_file, index_label=self.image_url_col)
self.df.to_csv(self.csv_file, index_label=self.url_column)

def get_caption(self, image_path: str) -> str:
if self.caption_column is None:
20 changes: 14 additions & 6 deletions helpers/data_backend/factory.py
@@ -1,6 +1,6 @@
from helpers.data_backend.local import LocalDataBackend
from helpers.data_backend.aws import S3DataBackend
from helpers.data_backend.csv import CSVDataBackend
from helpers.data_backend.csv_url_list import CSVDataBackend
from helpers.data_backend.base import BaseDataBackend
from helpers.training.default_settings import default, latest_config_version
from helpers.caching.text_embeds import TextEmbeddingCache
@@ -100,6 +100,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
output["config"]["csv_file"] = backend["csv_file"]
if "csv_caption_column" in backend:
output["config"]["csv_caption_column"] = backend["csv_caption_column"]
if "csv_url_column" in backend:
output["config"]["csv_url_column"] = backend["csv_url_column"]
if "crop_aspect" in backend:
choices = ["square", "preserve", "random", "closest"]
if backend.get("crop_aspect", None) not in choices:
@@ -148,8 +150,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
)
if "hash_filenames" in backend:
output["config"]["hash_filenames"] = backend["hash_filenames"]
if "shorten_filenames" in backend and backend.get("type") == "csv":
output["config"]["shorten_filenames"] = backend["shorten_filenames"]
if "hash_filenames" in backend and backend.get("type") == "csv":
output["config"]["hash_filenames"] = backend["hash_filenames"]

# check if caption_strategy=parquet with metadata_backend=json
if (
@@ -635,7 +637,7 @@ def configure_multi_databackend(
csv_file=backend["csv_file"],
csv_cache_dir=backend["csv_cache_dir"],
compress_cache=args.compress_disk_cache,
shorten_filenames=backend.get("shorten_filenames", False),
hash_filenames=backend.get("hash_filenames", False),
)
# init_backend["instance_data_dir"] = backend.get("instance_data_dir", backend.get("instance_data_root", backend.get("csv_cache_dir")))
init_backend["instance_data_dir"] = None
@@ -721,7 +723,7 @@ def configure_multi_databackend(
delete_unwanted_images=backend.get(
"delete_unwanted_images", args.delete_unwanted_images
),
cache_file_suffix=backend.get("cache_file_suffix", None),
cache_file_suffix=backend.get("cache_file_suffix", init_backend["id"]),
repeats=init_backend["config"].get("repeats", 0),
**metadata_backend_args,
)
@@ -1092,8 +1094,10 @@ def get_csv_backend(
id: str,
csv_file: str,
csv_cache_dir: str,
url_column: str,
caption_column: str,
compress_cache: bool = False,
shorten_filenames: bool = False,
hash_filenames: bool = False,
) -> CSVDataBackend:
from pathlib import Path

@@ -1102,8 +1106,11 @@
id=id,
csv_file=Path(csv_file),
image_cache_loc=csv_cache_dir,
url_column=url_column,
caption_column=caption_column,
compress_cache=compress_cache,
shorten_filenames=shorten_filenames,
hash_filenames=hash_filenames,
)


@@ -1112,6 +1119,7 @@ def check_csv_config(backend: dict, args) -> None:
"csv_file": "This is the path to the CSV file containing your image URLs.",
"csv_cache_dir": "This is the path to your temporary cache files where images will be stored. This can grow quite large.",
"csv_caption_column": "This is the column in your csv which contains the caption(s) for the samples.",
"csv_url_column": "This is the column in your csv that contains image urls or paths.",
}
for key in required_keys.keys():
if key not in backend:
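
For reference, a hedged example of a dataloader entry that would satisfy these checks (the values, and the id/type keys, are illustrative placeholders rather than taken from this commit):

example_backend = {
    "id": "my-csv-dataset",
    "type": "csv",
    "csv_file": "/workspace/dataset.csv",
    "csv_cache_dir": "/workspace/csv_cache",
    "csv_caption_column": "caption",
    "csv_url_column": "url",
    "hash_filenames": True,   # optional; forwarded for csv backends above
}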
4 changes: 2 additions & 2 deletions helpers/metadata/backends/parquet.py
@@ -467,8 +467,8 @@ def _process_for_bucket(
statistics["skipped"]["metadata_missing"] += 1
return aspect_ratio_bucket_indices

width_column = self.parquet_config.get("width_column")
height_column = self.parquet_config.get("height_column")
width_column = self.parquet_config.get("width_column", "width")
height_column = self.parquet_config.get("height_column", "height")
if width_column is None or height_column is None:
raise ValueError(
"ParquetMetadataBackend requires width and height columns to be defined."
35 changes: 31 additions & 4 deletions helpers/training/optimizer_param.py
@@ -9,7 +9,8 @@
logger.setLevel(target_level)

is_optimi_available = False
from helpers.training.adam_bfloat16 import AdamWBF16
from helpers.training.optimizers.adamw_bfloat16 import AdamWBF16
from helpers.training.optimizers.adamw_schedulefree import AdamWScheduleFreeKahan

try:
from optimum.quanto import QTensor
@@ -35,6 +36,17 @@
},
"class": AdamWBF16,
},
"adamw_schedulefree": {
"precision": "any",
"override_lr_scheduler": True,
"can_warmup": True,
"default_settings": {
"betas": (0.9, 0.999),
"weight_decay": 1e-2,
"eps": 1e-8,
},
"class": AdamWScheduleFreeKahan,
},
"optimi-stableadamw": {
"precision": "any",
"default_settings": {
@@ -154,8 +166,8 @@
}

deprecated_optimizers = {
"prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use optimi-stableadamw or optimi-lion instead - for decoupled learning rate, use --optimizer_config=decoupled_lr=True.",
"dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use optimi-stableadamw instead.",
"prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"adafactor": "Adafactor optimiser has been removed in favour of optimi-stableadamw, which offers improved memory efficiency and convergence.",
"adamw8bit": "AdamW8Bit has been removed in favour of optimi-adamw optimiser, which offers better low-precision support. Please use this or adamw_bf16 instead.",
}
@@ -206,6 +218,16 @@ def optimizer_parameters(optimizer, args):
raise ValueError(f"Optimizer {optimizer} not found.")


def is_lr_scheduler_disabled(optimizer: str):
"""Check if the optimizer has a built-in LR scheduler"""
is_disabled = False
if optimizer in optimizer_choices:
is_disabled = optimizer_choices.get(optimizer).get(
"override_lr_scheduler", False
)
return is_disabled
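
A short usage sketch (the surrounding trainer logic is assumed, not shown in this diff): because adamw_schedulefree sets override_lr_scheduler above, the helper reports that no external LR scheduler is needed.

if is_lr_scheduler_disabled("adamw_schedulefree"):
    lr_scheduler = None  # the optimizer manages its own schedule and warmup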


def show_optimizer_defaults(optimizer: str = None):
"""we'll print the defaults on a single line, eg. foo=bar, buz=baz"""
if optimizer is None:
@@ -260,7 +282,12 @@ def determine_optimizer_class_with_config(
else:
optimizer_class, optimizer_details = optimizer_parameters(args.optimizer, args)
default_settings = optimizer_details.get("default_settings")
logger.info(f"cls: {optimizer_class}, settings: {default_settings}")
if optimizer_details.get("can_warmup", False):
logger.info(
f"Optimizer contains LR scheduler, warmup steps will be set to {args.lr_warmup_steps}."
)
default_settings["warmup_steps"] = args.lr_warmup_steps
logger.info(f"cls: {optimizer_class}, settings: {default_settings}")
return default_settings, optimizer_class
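
For illustration, assuming --optimizer=adamw_schedulefree and --lr_warmup_steps=1000 (both values hypothetical), the branch above would return settings equivalent to:

default_settings = {
    "betas": (0.9, 0.999),   # from the adamw_schedulefree entry above
    "weight_decay": 1e-2,
    "eps": 1e-8,
    "warmup_steps": 1000,    # injected from args.lr_warmup_steps
}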


File renamed without changes.
