From abe647f9e965f9a8dc4f80f19454311b0075fd68 Mon Sep 17 00:00:00 2001 From: William Zhuk Date: Mon, 5 Aug 2024 10:50:47 -0700 Subject: [PATCH 01/18] Update and rename csv.py to csv_.py - used hashing of filenames instead of shortening - csv_ instead of csv to avoid potential issues with importing pandas (pandas can get confused as it has its own internal csv.py) - a bit of debugging of some issues --- helpers/data_backend/{csv.py => csv_.py} | 102 ++++++++++++++--------- 1 file changed, 62 insertions(+), 40 deletions(-) rename helpers/data_backend/{csv.py => csv_.py} (79%) diff --git a/helpers/data_backend/csv.py b/helpers/data_backend/csv_.py similarity index 79% rename from helpers/data_backend/csv.py rename to helpers/data_backend/csv_.py index b73e1c2d..673f193b 100644 --- a/helpers/data_backend/csv.py +++ b/helpers/data_backend/csv_.py @@ -1,11 +1,8 @@ import fnmatch -import io -from datetime import datetime -from urllib.request import url2pathname +import hashlib import pandas as pd import requests -from PIL import Image from helpers.data_backend.base import BaseDataBackend from helpers.image_manipulation.load import load_image @@ -28,23 +25,21 @@ def url_to_filename(url: str) -> str: return url.split("/")[-1] -def shorten_and_clean_filename(filename: str, no_op: bool): - if no_op: - return filename - filename = filename.replace("%20", "-").replace(" ", "-") - if len(filename) > 250: - filename = filename[:120] + "---" + filename[126:] - return filename +def str_hash(filename: str) -> str: + return str(hashlib.sha256(str(filename).encode()).hexdigest()) -def html_to_file_loc(parent_directory: Path, url: str, shorten_filenames: bool) -> str: +def path_to_hashed_path(path: Path, hash_filenames: bool) -> Path: + path = Path(path).resolve() + if hash_filenames: + return path.parent.joinpath(str_hash(path.stem) + path.suffix) + return path + + +def html_to_file_loc(parent_directory: Path, url: str, hash_filenames: bool) -> str: filename = url_to_filename(url) - cached_loc = str( - parent_directory.joinpath( - shorten_and_clean_filename(filename, no_op=shorten_filenames) - ) - ) - return cached_loc + cached_loc = path_to_hashed_path(parent_directory.joinpath(filename), hash_filenames) + return str(cached_loc.resolve()) class CSVDataBackend(BaseDataBackend): @@ -54,29 +49,28 @@ def __init__( id: str, csv_file: Path, compress_cache: bool = False, - image_url_col: str = "url", - caption_column: str = "caption", url_column: str = "url", + caption_column: str = "caption", image_cache_loc: Optional[str] = None, - shorten_filenames: bool = False, + hash_filenames: bool = True, ): self.id = id self.type = "csv" self.compress_cache = compress_cache - self.shorten_filenames = shorten_filenames + self.hash_filenames = hash_filenames self.csv_file = csv_file self.accelerator = accelerator - self.image_url_col = image_url_col - self.df = pd.read_csv(csv_file, index_col=image_url_col) + self.url_column = url_column + self.df = pd.read_csv(csv_file, index_col=url_column) self.df = self.df.groupby(level=0).last() # deduplicate by index (image loc) self.caption_column = caption_column - self.url_column = url_column self.image_cache_loc = ( Path(image_cache_loc) if image_cache_loc is not None else None ) def read(self, location, as_byteIO: bool = False): """Read and return the content of the file.""" + already_hashed = False if isinstance(location, Path): location = str(location.resolve()) if location.startswith("http"): @@ -85,11 +79,12 @@ def read(self, location, as_byteIO: bool = False): cached_loc = html_to_file_loc( self.image_cache_loc, location, - shorten_filenames=self.shorten_filenames, + self.hash_filenames, ) if os.path.exists(cached_loc): # found cache location = cached_loc + already_hashed = True else: # actually go to website data = requests.get(location, stream=True).raw.data @@ -99,8 +94,13 @@ def read(self, location, as_byteIO: bool = False): data = requests.get(location, stream=True).raw.data if not location.startswith("http"): # read from local file - with open(location, "rb") as file: - data = file.read() + hashed_location = path_to_hashed_path(location, hash_filenames=self.hash_filenames and not already_hashed) + try: + with open(hashed_location, "rb") as file: + data = file.read() + except FileNotFoundError as e: + print(f'ask was for file {location} bound to {hashed_location}') + raise e if not as_byteIO: return data return BytesIO(data) @@ -114,9 +114,7 @@ def write(self, filepath: Union[str, Path], data: Any) -> None: filepath = Path(filepath) # Not a huge fan of auto-shortening filenames, as we hash things for that in other cases. # However, this is copied in from the original Arcade-AI CSV backend implementation for compatibility. - filepath = filepath.parent.joinpath( - shorten_and_clean_filename(filepath.name, no_op=self.shorten_filenames) - ) + filepath = path_to_hashed_path(filepath, self.hash_filenames) filepath.parent.mkdir(parents=True, exist_ok=True) with open(filepath, "wb") as file: # Check if data is a Tensor, and if so, save it appropriately @@ -137,6 +135,7 @@ def delete(self, filepath): if filepath in self.df.index: self.df.drop(filepath, inplace=True) # self.save_state() + filepath = path_to_hashed_path(filepath, self.hash_filenames) if os.path.exists(filepath): logger.debug(f"Deleting file: {filepath}") os.remove(filepath) @@ -146,20 +145,21 @@ def delete(self, filepath): def exists(self, filepath): """Check if the file exists.""" - if isinstance(filepath, Path): - filepath = str(filepath.resolve()) - return filepath in self.df.index or os.path.exists(filepath) + if isinstance(filepath, str) and "http" in filepath: + return filepath in self.df.index + else: + filepath = path_to_hashed_path(filepath, self.hash_filenames) + return os.path.exists(filepath) def open_file(self, filepath, mode): """Open the file in the specified mode.""" - return open(filepath, mode) + return open(path_to_hashed_path(filepath, self.hash_filenames), mode) def list_files(self, str_pattern: str, instance_data_dir: str = None) -> tuple: """ List all files matching the pattern. Creates Path objects of each file found. """ - # print frame contents logger.debug( f"CSVDataBackend.list_files: str_pattern={str_pattern}, instance_data_dir={instance_data_dir}" ) @@ -197,7 +197,7 @@ def read_image(self, filepath: str, delete_problematic_images: bool = False): filepath = filepath.replace("\x00", "") try: image_data = self.read(filepath, as_byteIO=True) - image = load_image(image_data) + image = load_image(image_data).resize((1024, 1024)) return image except Exception as e: import traceback @@ -279,9 +279,9 @@ def torch_save(self, data, location: Union[str, Path, BytesIO]): """ Save a torch tensor to a file. """ + if isinstance(location, str) or isinstance(location, Path): - if location not in self.df.index: - self.df.loc[location] = pd.Series() + location = path_to_hashed_path(location, self.hash_filenames) location = self.open_file(location, "wb") if self.compress_cache: @@ -297,9 +297,31 @@ def write_batch(self, filepaths: list, data_list: list) -> None: self.write(filepath, data) def save_state(self): - self.df.to_csv(self.csv_file, index_label=self.image_url_col) + self.df.to_csv(self.csv_file, index_label=self.url_column) def get_caption(self, image_path: str) -> str: if self.caption_column is None: raise ValueError("Cannot retrieve caption from csv, as one is not set.") return self.df.loc[image_path, self.caption_column] + + +if __name__ == "__main__": + data = CSVDataBackend( + None, + id="test", + csv_file=Path("/media/second8TBNVME/cache/SimpleTuner/sd3/jewelry-v15.csv"), + image_cache_loc="/media/second8TBNVME/cache/SimpleTuner/image-cache", + url_column="Image", + caption_column="Long Caption", + compress_cache=False, + hash_filenames=True, + ) + results = \ + data.list_files("*.[jJpP][pPnN][gG]", instance_data_dir="/media/second8TBNVME/cache/SimpleTuner/jewelry-v15")[0][2] + # print(results) + test = data.exists( + "https://storage.googleapis.com/internal-assets-arcade-ai-prod/xbnwoi287kc/long%20slim%20dangle%20earringss.png.txt") + for file in results: + image = data.read_image(file, delete_problematic_images=False) + print(image.size, file) + caption = data.get_caption(file) From 6f9150260a56cef06229a38b8e3dc9e80deffee4 Mon Sep 17 00:00:00 2001 From: William Zhuk Date: Mon, 5 Aug 2024 10:53:58 -0700 Subject: [PATCH 02/18] remove some parts that are non-universal --- helpers/data_backend/csv_.py | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/helpers/data_backend/csv_.py b/helpers/data_backend/csv_.py index 673f193b..bae0b1b4 100644 --- a/helpers/data_backend/csv_.py +++ b/helpers/data_backend/csv_.py @@ -197,7 +197,7 @@ def read_image(self, filepath: str, delete_problematic_images: bool = False): filepath = filepath.replace("\x00", "") try: image_data = self.read(filepath, as_byteIO=True) - image = load_image(image_data).resize((1024, 1024)) + image = load_image(image_data) return image except Exception as e: import traceback @@ -303,25 +303,3 @@ def get_caption(self, image_path: str) -> str: if self.caption_column is None: raise ValueError("Cannot retrieve caption from csv, as one is not set.") return self.df.loc[image_path, self.caption_column] - - -if __name__ == "__main__": - data = CSVDataBackend( - None, - id="test", - csv_file=Path("/media/second8TBNVME/cache/SimpleTuner/sd3/jewelry-v15.csv"), - image_cache_loc="/media/second8TBNVME/cache/SimpleTuner/image-cache", - url_column="Image", - caption_column="Long Caption", - compress_cache=False, - hash_filenames=True, - ) - results = \ - data.list_files("*.[jJpP][pPnN][gG]", instance_data_dir="/media/second8TBNVME/cache/SimpleTuner/jewelry-v15")[0][2] - # print(results) - test = data.exists( - "https://storage.googleapis.com/internal-assets-arcade-ai-prod/xbnwoi287kc/long%20slim%20dangle%20earringss.png.txt") - for file in results: - image = data.read_image(file, delete_problematic_images=False) - print(image.size, file) - caption = data.get_caption(file) From 52b65545109eb223ea448a62a34639f5b2f68a5a Mon Sep 17 00:00:00 2001 From: William Zhuk Date: Mon, 5 Aug 2024 10:54:21 -0700 Subject: [PATCH 03/18] Update factory.py --- helpers/data_backend/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index 0b0ca503..0862c6a2 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -1,6 +1,6 @@ from helpers.data_backend.local import LocalDataBackend from helpers.data_backend.aws import S3DataBackend -from helpers.data_backend.csv import CSVDataBackend +from helpers.data_backend.csv_ import CSVDataBackend from helpers.data_backend.base import BaseDataBackend from helpers.training.default_settings import default, latest_config_version from helpers.caching.text_embeds import TextEmbeddingCache From c693c7d0b650293fed8bfff16895e0d787565a76 Mon Sep 17 00:00:00 2001 From: William Zhuk Date: Mon, 5 Aug 2024 11:23:17 -0700 Subject: [PATCH 04/18] Update factory.py --- helpers/data_backend/factory.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index 0862c6a2..fc5f8a48 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -99,6 +99,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict: output["config"]["csv_file"] = backend["csv_file"] if "csv_caption_column" in backend: output["config"]["csv_caption_column"] = backend["csv_caption_column"] + if "csv_url_column" in backend: + output["config"]["csv_url_column"] = backend["csv_url_column"] if "crop_aspect" in backend: choices = ["square", "preserve", "random"] if backend.get("crop_aspect", None) not in choices: @@ -147,8 +149,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict: ) if "hash_filenames" in backend: output["config"]["hash_filenames"] = backend["hash_filenames"] - if "shorten_filenames" in backend and backend.get("type") == "csv": - output["config"]["shorten_filenames"] = backend["shorten_filenames"] + if "hash_filenames" in backend and backend.get("type") == "csv": + output["config"]["hash_filenames"] = backend["hash_filenames"] # check if caption_strategy=parquet with metadata_backend=json if ( @@ -593,7 +595,7 @@ def configure_multi_databackend( csv_file=backend["csv_file"], csv_cache_dir=backend["csv_cache_dir"], compress_cache=args.compress_disk_cache, - shorten_filenames=backend.get("shorten_filenames", False), + hash_filenames=backend.get("hash_filenames", False), ) # init_backend["instance_data_dir"] = backend.get("instance_data_dir", backend.get("instance_data_root", backend.get("csv_cache_dir"))) init_backend["instance_data_dir"] = None @@ -1039,7 +1041,7 @@ def get_csv_backend( csv_file: str, csv_cache_dir: str, compress_cache: bool = False, - shorten_filenames: bool = False, + hash_filenames: bool = False, ) -> CSVDataBackend: from pathlib import Path @@ -1048,8 +1050,11 @@ def get_csv_backend( id=id, csv_file=Path(csv_file), image_cache_loc=csv_cache_dir, + url_column=url_column, + caption_column=caption_column, compress_cache=compress_cache, shorten_filenames=shorten_filenames, + hash_filenames=hash_filenames, ) @@ -1058,6 +1063,7 @@ def check_csv_config(backend: dict, args) -> None: "csv_file": "This is the path to the CSV file containing your image URLs.", "csv_cache_dir": "This is the path to your temporary cache files where images will be stored. This can grow quite large.", "csv_caption_column": "This is the column in your csv which contains the caption(s) for the samples.", + "csv_url_column": "This is the column in your csv that contains image urls or paths.", } for key in required_keys.keys(): if key not in backend: From ff0739cd268ef44ffaa989a225fd79ce33725229 Mon Sep 17 00:00:00 2001 From: William Zhuk Date: Mon, 5 Aug 2024 11:27:17 -0700 Subject: [PATCH 05/18] Update factory.py --- helpers/data_backend/factory.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index fc5f8a48..cca03131 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -1040,6 +1040,8 @@ def get_csv_backend( id: str, csv_file: str, csv_cache_dir: str, + url_column: str, + caption_column: str, compress_cache: bool = False, hash_filenames: bool = False, ) -> CSVDataBackend: From 614764ddd5e95a18ced057a43f5eeb8c2221056e Mon Sep 17 00:00:00 2001 From: William Zhuk Date: Mon, 19 Aug 2024 09:44:12 -0700 Subject: [PATCH 06/18] Update helpers/data_backend/csv_.py Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com> --- helpers/data_backend/csv_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/data_backend/csv_.py b/helpers/data_backend/csv_.py index bae0b1b4..21a93664 100644 --- a/helpers/data_backend/csv_.py +++ b/helpers/data_backend/csv_.py @@ -99,7 +99,7 @@ def read(self, location, as_byteIO: bool = False): with open(hashed_location, "rb") as file: data = file.read() except FileNotFoundError as e: - print(f'ask was for file {location} bound to {hashed_location}') + tqdm.write(f'ask was for file {location} bound to {hashed_location}') raise e if not as_byteIO: return data From 9b833d22f636660b192606326069406d38ccc725 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 22 Aug 2024 18:04:19 -0600 Subject: [PATCH 07/18] backwards-compat breaking: set cache_file_suffix by default to the id of the dataset to avoid conflicts more easily. to fix the breakage, manually set cache_file_suffix to null or just let it rebuild the buckets. sorry --- helpers/data_backend/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index 7096a4de..13e57f41 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -721,7 +721,7 @@ def configure_multi_databackend( delete_unwanted_images=backend.get( "delete_unwanted_images", args.delete_unwanted_images ), - cache_file_suffix=backend.get("cache_file_suffix", None), + cache_file_suffix=backend.get("cache_file_suffix", init_backend["id"]), repeats=init_backend["config"].get("repeats", 0), **metadata_backend_args, ) From c464201d553890eb9fc31087b2ea695c92ab1e63 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 22 Aug 2024 18:07:04 -0600 Subject: [PATCH 08/18] parquet: set width and height column value by default to width and height --- helpers/metadata/backends/parquet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index fc4204d2..18aed3ff 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -467,8 +467,8 @@ def _process_for_bucket( statistics["skipped"]["metadata_missing"] += 1 return aspect_ratio_bucket_indices - width_column = self.parquet_config.get("width_column") - height_column = self.parquet_config.get("height_column") + width_column = self.parquet_config.get("width_column", "width") + height_column = self.parquet_config.get("height_column", "height") if width_column is None or height_column is None: raise ValueError( "ParquetMetadataBackend requires width and height columns to be defined." From 8bb8481c62c3f87c8cbaaff8666ae1f3bc6235df Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 22 Aug 2024 18:07:29 -0600 Subject: [PATCH 09/18] global shame: hide the existence of gradient accumulation steps --- train.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/train.sh b/train.sh index bbc27eb0..dfd82c05 100755 --- a/train.sh +++ b/train.sh @@ -319,7 +319,6 @@ if ! [ -z "$USE_GRADIENT_CHECKPOINTING" ] && [[ "$USE_GRADIENT_CHECKPOINTING" == fi if [ -z "$GRADIENT_ACCUMULATION_STEPS" ]; then - printf "GRADIENT_ACCUMULATION_STEPS not set, defaulting to 1.\n" export GRADIENT_ACCUMULATION_STEPS=1 fi From b549f847b30a533a6128c4b78ce0d2ceb9eb08c2 Mon Sep 17 00:00:00 2001 From: Bagheera <59658056+bghira@users.noreply.github.com> Date: Fri, 23 Aug 2024 07:41:38 -0600 Subject: [PATCH 10/18] Rename csv_.py to csv_url_list.py --- helpers/data_backend/{csv_.py => csv_url_list.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename helpers/data_backend/{csv_.py => csv_url_list.py} (100%) diff --git a/helpers/data_backend/csv_.py b/helpers/data_backend/csv_url_list.py similarity index 100% rename from helpers/data_backend/csv_.py rename to helpers/data_backend/csv_url_list.py From 48426a44301753e954ad04c50523d25b1ae433c3 Mon Sep 17 00:00:00 2001 From: Bagheera <59658056+bghira@users.noreply.github.com> Date: Fri, 23 Aug 2024 07:41:53 -0600 Subject: [PATCH 11/18] Update factory.py --- helpers/data_backend/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index cca03131..b53d5dcd 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -1,6 +1,6 @@ from helpers.data_backend.local import LocalDataBackend from helpers.data_backend.aws import S3DataBackend -from helpers.data_backend.csv_ import CSVDataBackend +from helpers.data_backend.csv_url_list import CSVDataBackend from helpers.data_backend.base import BaseDataBackend from helpers.training.default_settings import default, latest_config_version from helpers.caching.text_embeds import TextEmbeddingCache From 177190dc385f771bb14a4cfd696be7c082bc116c Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 23 Aug 2024 07:47:33 -0600 Subject: [PATCH 12/18] trainer: do not override NUM_EPOCHS and MAX_NUM_STEPS in config.env.example --- config/config.env.example | 5 +++-- train.sh | 19 ++++++++++++++----- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/config/config.env.example b/config/config.env.example index 236172c0..d00f74a2 100644 --- a/config/config.env.example +++ b/config/config.env.example @@ -81,9 +81,10 @@ export TRACKER_PROJECT_NAME="${MODEL_TYPE}-training" export TRACKER_RUN_NAME="simpletuner-sdxl" # Max number of steps OR epochs can be used. Not both. -export MAX_NUM_STEPS=30000 +# You MUST uncomment and set one of these. +#export MAX_NUM_STEPS=30000 # Will likely overtrain, but that's fine. -export NUM_EPOCHS=0 +#export NUM_EPOCHS=0 # A convenient prefix for all of your training paths. # These may be absolute or relative paths. Here, we are using relative paths. diff --git a/train.sh b/train.sh index dfd82c05..78d00142 100755 --- a/train.sh +++ b/train.sh @@ -108,10 +108,6 @@ if [ -z "${TRACKER_RUN_NAME}" ]; then printf "TRACKER_RUN_NAME not set, exiting.\n" exit 1 fi -if [ -z "${NUM_EPOCHS}" ]; then - printf "NUM_EPOCHS not set, exiting.\n" - exit 1 -fi if [ -z "${VALIDATION_PROMPT}" ]; then printf "VALIDATION_PROMPT not set, exiting.\n" exit 1 @@ -307,7 +303,20 @@ if ! [ -f "$DATALOADER_CONFIG" ]; then printf "DATALOADER_CONFIG file %s not found, cannot continue.\n" "${DATALOADER_CONFIG}" exit 1 fi - +if [ -z "$MAX_TRAIN_STEPS" ] && [ -z "$NUM_EPOCHS" ]; then + echo "Neither MAX_TRAIN_STEPS or NUM_EPOCHS were defined." + exit 1 +fi +if [ -z "$MAX_TRAIN_STEPS" ]; then + export MAX_TRAIN_STEPS=0 +fi +if [ -z "$NUM_EPOCHS" ]; then + export NUM_EPOCHS=0 +fi +if [[ "$MAX_TRAIN_STEPS" == "0" ]] && [[ "$NUM_EPOCHS" == "0" ]]; then + echo "Both MAX_TRAIN_STEPS and NUM_EPOCHS cannot be zero." + exit 1 +fi export SNR_GAMMA_ARG="" if [ -n "$MIN_SNR_GAMMA" ]; then export SNR_GAMMA_ARG="--snr_gamma=${MIN_SNR_GAMMA}" From 7c1f047e6e34533fe95e89235480384fc2ac5dc8 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 23 Aug 2024 16:02:10 -0600 Subject: [PATCH 13/18] new optimiser, adamw_schedulefree --- helpers/training/optimizer_param.py | 16 +- .../adam_bfloat16/__init__.py | 0 .../adam_bfloat16/stochastic/__init__.py | 0 .../optimizers/adamw_schedulefree/__init__.py | 142 ++++++++++++++++++ 4 files changed, 155 insertions(+), 3 deletions(-) rename helpers/training/{ => optimizers}/adam_bfloat16/__init__.py (100%) rename helpers/training/{ => optimizers}/adam_bfloat16/stochastic/__init__.py (100%) create mode 100644 helpers/training/optimizers/adamw_schedulefree/__init__.py diff --git a/helpers/training/optimizer_param.py b/helpers/training/optimizer_param.py index c27d771c..f3ed219d 100644 --- a/helpers/training/optimizer_param.py +++ b/helpers/training/optimizer_param.py @@ -9,7 +9,8 @@ logger.setLevel(target_level) is_optimi_available = False -from helpers.training.adam_bfloat16 import AdamWBF16 +from helpers.training.optimizers.adamw_bfloat16 import AdamWBF16 +from helpers.training.optimizers.adamw_schedulefree import AdamWScheduleFreeKahan try: from optimum.quanto import QTensor @@ -35,6 +36,15 @@ }, "class": AdamWBF16, }, + "adamw_schedulefree": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.999), + "weight_decay": 1e-2, + "eps": 1e-8, + }, + "class": AdamWScheduleFreeKahan, + }, "optimi-stableadamw": { "precision": "any", "default_settings": { @@ -154,8 +164,8 @@ } deprecated_optimizers = { - "prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use optimi-stableadamw or optimi-lion instead - for decoupled learning rate, use --optimizer_config=decoupled_lr=True.", - "dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use optimi-stableadamw instead.", + "prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.", + "dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.", "adafactor": "Adafactor optimiser has been removed in favour of optimi-stableadamw, which offers improved memory efficiency and convergence.", "adamw8bit": "AdamW8Bit has been removed in favour of optimi-adamw optimiser, which offers better low-precision support. Please use this or adamw_bf16 instead.", } diff --git a/helpers/training/adam_bfloat16/__init__.py b/helpers/training/optimizers/adam_bfloat16/__init__.py similarity index 100% rename from helpers/training/adam_bfloat16/__init__.py rename to helpers/training/optimizers/adam_bfloat16/__init__.py diff --git a/helpers/training/adam_bfloat16/stochastic/__init__.py b/helpers/training/optimizers/adam_bfloat16/stochastic/__init__.py similarity index 100% rename from helpers/training/adam_bfloat16/stochastic/__init__.py rename to helpers/training/optimizers/adam_bfloat16/stochastic/__init__.py diff --git a/helpers/training/optimizers/adamw_schedulefree/__init__.py b/helpers/training/optimizers/adamw_schedulefree/__init__.py new file mode 100644 index 00000000..b58e8814 --- /dev/null +++ b/helpers/training/optimizers/adamw_schedulefree/__init__.py @@ -0,0 +1,142 @@ +import torch +from torch.optim.optimizer import Optimizer +import math +from typing import Iterable + + +class AdamWScheduleFreeKahan(Optimizer): + """AdamW optimizer with schedule-free adjustments and Kahan summation. + + Args: + params: Iterable of parameters to optimize or dicts defining parameter groups. + lr: Learning rate. + betas: Coefficients for gradient and squared gradient moving averages (default: (0.9, 0.999)). + eps: Added to denominator to improve numerical stability (default: 1e-8). + weight_decay: Weight decay coefficient (default: 1e-2). + warmup_steps: Number of steps to warm up the learning rate (default: 0). + kahan_sum: Enables Kahan summation for more accurate parameter updates when training in low precision. + """ + + def __init__( + self, + params: Iterable, + lr: float = 1e-3, + betas: tuple = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + warmup_steps: int = 0, + kahan_sum: bool = True, + ): + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + warmup_steps=warmup_steps, + kahan_sum=kahan_sum, + ) + super(AdamWScheduleFreeKahan, self).__init__(params, defaults) + self.k = 0 + self.lr_max = -1.0 + self.weight_sum = 0.0 + + def _initialize_state(self, state, p): + if "step" not in state: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if self.defaults["kahan_sum"]: + state["kahan_comp"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + def eval(self): + for group in self.param_groups: + train_mode = group.get("train_mode", True) + beta1, _ = group["betas"] + if train_mode: + for p in group["params"]: + state = self.state[p] + if "z" in state: + # Set p.data to x + p.data.lerp_( + end=state["z"].to(p.data.device), weight=1 - 1 / beta1 + ) + group["train_mode"] = False + + def train(self): + for group in self.param_groups: + train_mode = group.get("train_mode", False) + beta1, _ = group["betas"] + if not train_mode: + for p in group["params"]: + state = self.state[p] + if "z" in state: + # Set p.data to y + p.data.lerp_(end=state["z"].to(p.data.device), weight=1 - beta1) + group["train_mode"] = True + + def step(self, closure=None): + """Performs a single optimization step.""" + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + beta1, beta2 = group["betas"] + lr = group["lr"] + eps = group["eps"] + weight_decay = group["weight_decay"] + warmup_steps = group["warmup_steps"] + kahan_sum = group["kahan_sum"] + + k = self.k + + # Adjust learning rate with warmup + if k < warmup_steps: + sched = (k + 1) / warmup_steps + else: + sched = 1.0 + + bias_correction2 = 1 - beta2 ** (k + 1) + adjusted_lr = lr * sched * (bias_correction2**0.5) + self.lr_max = max(adjusted_lr, self.lr_max) + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + + state = self.state[p] + self._initialize_state(state, p) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + if kahan_sum: + kahan_comp = state["kahan_comp"] + grad.add_(kahan_comp) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = exp_avg_sq.sqrt().add_(eps) + + step_size = adjusted_lr / (bias_correction2**0.5) + + if weight_decay != 0: + p.data.add_(p.data, alpha=-weight_decay) + + # Kahan summation to improve precision + step = exp_avg / denom + p.data.add_(-step_size * step) + + if kahan_sum: + buffer = p.data.add(-step_size * step) + kahan_comp.copy_(p.data.sub(buffer).add(buffer.sub_(p.data))) + + self.k += 1 + + return loss From 0cc04402801dbee75281d6ea7439d58afe50e5dd Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 23 Aug 2024 16:14:28 -0600 Subject: [PATCH 14/18] schedulefree integration --- helpers/training/optimizer_param.py | 8 +++ .../__init__.py | 0 .../stochastic/__init__.py | 0 train.py | 51 +++++++++++++++---- 4 files changed, 48 insertions(+), 11 deletions(-) rename helpers/training/optimizers/{adam_bfloat16 => adamw_bfloat16}/__init__.py (100%) rename helpers/training/optimizers/{adam_bfloat16 => adamw_bfloat16}/stochastic/__init__.py (100%) diff --git a/helpers/training/optimizer_param.py b/helpers/training/optimizer_param.py index f3ed219d..9be2c728 100644 --- a/helpers/training/optimizer_param.py +++ b/helpers/training/optimizer_param.py @@ -38,6 +38,7 @@ }, "adamw_schedulefree": { "precision": "any", + "override_lr_scheduler": True, "default_settings": { "betas": (0.9, 0.999), "weight_decay": 1e-2, @@ -216,6 +217,13 @@ def optimizer_parameters(optimizer, args): raise ValueError(f"Optimizer {optimizer} not found.") +def is_lr_scheduler_disabled(optimizer: str): + """Check if the optimizer has a built-in LR scheduler""" + if optimizer in optimizer_choices: + return optimizer_choices.get(optimizer).get("override_lr_scheduler", False) + return False + + def show_optimizer_defaults(optimizer: str = None): """we'll print the defaults on a single line, eg. foo=bar, buz=baz""" if optimizer is None: diff --git a/helpers/training/optimizers/adam_bfloat16/__init__.py b/helpers/training/optimizers/adamw_bfloat16/__init__.py similarity index 100% rename from helpers/training/optimizers/adam_bfloat16/__init__.py rename to helpers/training/optimizers/adamw_bfloat16/__init__.py diff --git a/helpers/training/optimizers/adam_bfloat16/stochastic/__init__.py b/helpers/training/optimizers/adamw_bfloat16/stochastic/__init__.py similarity index 100% rename from helpers/training/optimizers/adam_bfloat16/stochastic/__init__.py rename to helpers/training/optimizers/adamw_bfloat16/stochastic/__init__.py diff --git a/train.py b/train.py index e9982288..d94d284c 100644 --- a/train.py +++ b/train.py @@ -33,6 +33,8 @@ from helpers.training.validation import Validation, prepare_validation_prompt_list from helpers.training.state_tracker import StateTracker from helpers.training.schedulers import load_scheduler_from_args +from helpers.training.custom_schedule import get_lr_scheduler +from helpers.training.optimizer_param import is_lr_scheduler_disabled from helpers.training.adapter import determine_adapter_target_modules, load_lora_weights from helpers.training.diffusion_model import load_diffusion_model from helpers.training.text_encoding import ( @@ -839,9 +841,14 @@ def main(): f" {args.num_train_epochs} epochs and {num_update_steps_per_epoch} steps per epoch." ) overrode_max_train_steps = True - logger.info( - f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps" - ) + if not is_lr_scheduler_disabled(args.optimizer): + logger.info( + f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps" + ) + else: + logger.info( + "Using experimental AdamW ScheduleFree optimiser from Facebook. Experimental due to newly added Kahan summation." + ) if args.layer_freeze_strategy == "bitfit": from helpers.training.model_freeze import apply_bitfit_freezing @@ -978,8 +985,9 @@ def main(): optimizer, ) - from helpers.training.custom_schedule import get_lr_scheduler - + if is_lr_scheduler_disabled(args.optimizer): + # we don't use LR schedulers with schedulefree schedulers (lol) + lr_scheduler = None if not use_deepspeed_scheduler: logger.info( f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps" @@ -994,10 +1002,11 @@ def main(): total_num_steps=args.max_train_steps, warmup_num_steps=args.lr_warmup_steps, ) - if hasattr(lr_scheduler, "num_update_steps_per_epoch"): - lr_scheduler.num_update_steps_per_epoch = num_update_steps_per_epoch - if hasattr(lr_scheduler, "last_step"): - lr_scheduler.last_step = global_resume_step + if lr_scheduler is not None: + if hasattr(lr_scheduler, "num_update_steps_per_epoch"): + lr_scheduler.num_update_steps_per_epoch = num_update_steps_per_epoch + if hasattr(lr_scheduler, "last_step"): + lr_scheduler.last_step = global_resume_step accelerator.wait_for_everyone() @@ -1285,6 +1294,10 @@ def main(): if "sampler" in backend: backend["sampler"].log_state() + if is_lr_scheduler_disabled(args.optimizer) and hasattr(optimizer, "train"): + # we typically have to call train() on the optim for schedulefree. + optimizer.train() + total_steps_remaining_at_start = args.max_train_steps # We store the number of dataset resets that have occurred inside the checkpoint. first_epoch = StateTracker.get_epoch() @@ -1399,7 +1412,11 @@ def main(): if webhook_handler is not None: webhook_handler.send(message=initial_msg) if args.validation_on_startup and global_step <= 1: + if is_lr_scheduler_disabled(args.optimizer): + optimizer.eval() validation.run_validations(validation_type="base_model", step=0) + if is_lr_scheduler_disabled(args.optimizer): + optimizer.train() # Only show the progress bar once on each machine. show_progress_bar = True @@ -2051,8 +2068,9 @@ def main(): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: try: - lr_scheduler.step(**scheduler_kwargs) - lr = lr_scheduler.get_last_lr()[0] + if lr_scheduler is not None: + lr_scheduler.step(**scheduler_kwargs) + lr = lr_scheduler.get_last_lr()[0] except Exception as e: logger.error( f"Failed to get the last learning rate from the scheduler. Error: {e}" @@ -2164,7 +2182,12 @@ def main(): args.output_dir, f"checkpoint-{global_step}" ) print("\n") + # schedulefree optim needs the optimizer to be in eval mode to save the state (and then back to train after) + if is_lr_scheduler_disabled(args.optimizer): + optimizer.eval() accelerator.save_state(save_path) + if is_lr_scheduler_disabled(args.optimizer): + optimizer.train() for _, backend in StateTracker.get_data_backends().items(): if "sampler" in backend: logger.debug(f"Backend: {backend}") @@ -2185,7 +2208,11 @@ def main(): "lr": lr, } progress_bar.set_postfix(**logs) + if is_lr_scheduler_disabled(args.optimizer): + optimizer.eval() validation.run_validations(validation_type="intermediary", step=step) + if is_lr_scheduler_disabled(args.optimizer): + optimizer.train() if ( args.push_to_hub and args.push_checkpoints_to_hub @@ -2220,6 +2247,8 @@ def main(): # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: + if is_lr_scheduler_disabled(args.optimizer): + optimizer.eval() validation_images = validation.run_validations( validation_type="final", step=global_step, From d29c4807709fc4d338b61c8c69f9f0a3ed4b118d Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 23 Aug 2024 17:14:21 -0600 Subject: [PATCH 15/18] fix train.sh check for step/epoch being zero --- train.sh | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/train.sh b/train.sh index 78d00142..4eb20cd8 100755 --- a/train.sh +++ b/train.sh @@ -303,18 +303,18 @@ if ! [ -f "$DATALOADER_CONFIG" ]; then printf "DATALOADER_CONFIG file %s not found, cannot continue.\n" "${DATALOADER_CONFIG}" exit 1 fi -if [ -z "$MAX_TRAIN_STEPS" ] && [ -z "$NUM_EPOCHS" ]; then - echo "Neither MAX_TRAIN_STEPS or NUM_EPOCHS were defined." +if [ -z "$MAX_NUM_STEPS" ] && [ -z "$NUM_EPOCHS" ]; then + echo "Neither MAX_NUM_STEPS or NUM_EPOCHS were defined." exit 1 fi -if [ -z "$MAX_TRAIN_STEPS" ]; then - export MAX_TRAIN_STEPS=0 +if [ -z "$MAX_NUM_STEPS" ]; then + export MAX_NUM_STEPS=0 fi if [ -z "$NUM_EPOCHS" ]; then export NUM_EPOCHS=0 fi -if [[ "$MAX_TRAIN_STEPS" == "0" ]] && [[ "$NUM_EPOCHS" == "0" ]]; then - echo "Both MAX_TRAIN_STEPS and NUM_EPOCHS cannot be zero." +if [ "$MAX_NUM_STEPS" -lt 1 ] && [ "$NUM_EPOCHS" -lt 1 ]; then + echo "Both MAX_NUM_STEPS {$MAX_NUM_STEPS} and NUM_EPOCHS {$NUM_EPOCHS} cannot be zero." exit 1 fi export SNR_GAMMA_ARG="" @@ -367,9 +367,9 @@ if [ -n "$ASPECT_BUCKET_ROUNDING" ]; then export ASPECT_BUCKET_ROUNDING_ARGS="--aspect_bucket_rounding=${ASPECT_BUCKET_ROUNDING}" fi -export MAX_TRAIN_STEPS_ARGS="" +export MAX_NUM_STEPS_ARGS="" if [ -n "$MAX_NUM_STEPS" ] && [[ "$MAX_NUM_STEPS" != 0 ]]; then - export MAX_TRAIN_STEPS_ARGS="--max_train_steps=${MAX_NUM_STEPS}" + export MAX_NUM_STEPS_ARGS="--max_train_steps=${MAX_NUM_STEPS}" fi export CONTROLNET_ARGS="" @@ -382,7 +382,7 @@ fi accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --num_processes="${TRAINING_NUM_PROCESSES}" --num_machines="${TRAINING_NUM_MACHINES}" --dynamo_backend="${TRAINING_DYNAMO_BACKEND}" train.py \ --model_type="${MODEL_TYPE}" ${DORA_ARGS} --pretrained_model_name_or_path="${MODEL_NAME}" ${XFORMERS_ARG} ${GRADIENT_ARG} --set_grads_to_none --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ --resume_from_checkpoint="${RESUME_CHECKPOINT}" ${DELETE_ARGS} ${SNR_GAMMA_ARG} --data_backend_config="${DATALOADER_CONFIG}" \ - --num_train_epochs=${NUM_EPOCHS} ${MAX_TRAIN_STEPS_ARGS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \ + --num_train_epochs=${NUM_EPOCHS} ${MAX_NUM_STEPS_ARGS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \ ${OPTIMIZER_ARG} --learning_rate="${LEARNING_RATE}" --lr_scheduler="${LR_SCHEDULE}" --seed "${TRAINING_SEED}" --lr_warmup_steps="${LR_WARMUP_STEPS}" \ --output_dir="${OUTPUT_DIR}" ${BITFIT_ARGS} ${ASPECT_BUCKET_ROUNDING_ARGS} \ --inference_scheduler_timestep_spacing="${INFERENCE_SCHEDULER_TIMESTEP_SPACING}" --training_scheduler_timestep_spacing="${TRAINING_SCHEDULER_TIMESTEP_SPACING}" \ From e9758a12cad82bd93c998ad7b45b111568d4366f Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 23 Aug 2024 17:14:49 -0600 Subject: [PATCH 16/18] schedulefree: retrieve last_lr --- helpers/training/optimizers/adamw_schedulefree/__init__.py | 2 ++ train.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/helpers/training/optimizers/adamw_schedulefree/__init__.py b/helpers/training/optimizers/adamw_schedulefree/__init__.py index b58e8814..eedfbabc 100644 --- a/helpers/training/optimizers/adamw_schedulefree/__init__.py +++ b/helpers/training/optimizers/adamw_schedulefree/__init__.py @@ -38,6 +38,7 @@ def __init__( super(AdamWScheduleFreeKahan, self).__init__(params, defaults) self.k = 0 self.lr_max = -1.0 + self.last_lr = -1.0 self.weight_sum = 0.0 def _initialize_state(self, state, p): @@ -138,5 +139,6 @@ def step(self, closure=None): kahan_comp.copy_(p.data.sub(buffer).add(buffer.sub_(p.data))) self.k += 1 + self.last_lr = adjusted_lr return loss diff --git a/train.py b/train.py index d94d284c..7101bf68 100644 --- a/train.py +++ b/train.py @@ -2071,6 +2071,8 @@ def main(): if lr_scheduler is not None: lr_scheduler.step(**scheduler_kwargs) lr = lr_scheduler.get_last_lr()[0] + elif hasattr(optimizer, "last_lr"): + lr = optimizer.last_lr except Exception as e: logger.error( f"Failed to get the last learning rate from the scheduler. Error: {e}" From 4603949fd507d2b1e002714e94be55990926767e Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 23 Aug 2024 17:56:20 -0600 Subject: [PATCH 17/18] schedulefree: retrieve LR using statetracker --- helpers/training/optimizer_param.py | 15 +++++++-- .../optimizers/adamw_schedulefree/__init__.py | 5 +++ helpers/training/state_tracker.py | 11 +++++++ train.py | 31 +++++++++++-------- 4 files changed, 46 insertions(+), 16 deletions(-) diff --git a/helpers/training/optimizer_param.py b/helpers/training/optimizer_param.py index 9be2c728..40fe007f 100644 --- a/helpers/training/optimizer_param.py +++ b/helpers/training/optimizer_param.py @@ -39,6 +39,7 @@ "adamw_schedulefree": { "precision": "any", "override_lr_scheduler": True, + "can_warmup": True, "default_settings": { "betas": (0.9, 0.999), "weight_decay": 1e-2, @@ -219,9 +220,12 @@ def optimizer_parameters(optimizer, args): def is_lr_scheduler_disabled(optimizer: str): """Check if the optimizer has a built-in LR scheduler""" + is_disabled = False if optimizer in optimizer_choices: - return optimizer_choices.get(optimizer).get("override_lr_scheduler", False) - return False + is_disabled = optimizer_choices.get(optimizer).get( + "override_lr_scheduler", False + ) + return is_disabled def show_optimizer_defaults(optimizer: str = None): @@ -278,7 +282,12 @@ def determine_optimizer_class_with_config( else: optimizer_class, optimizer_details = optimizer_parameters(args.optimizer, args) default_settings = optimizer_details.get("default_settings") - logger.info(f"cls: {optimizer_class}, settings: {default_settings}") + if optimizer_details.get("can_warmup", False): + logger.info( + f"Optimizer contains LR scheduler, warmup steps will be set to {args.lr_warmup_steps}." + ) + default_settings["warmup_steps"] = args.lr_warmup_steps + logger.info(f"cls: {optimizer_class}, settings: {default_settings}") return default_settings, optimizer_class diff --git a/helpers/training/optimizers/adamw_schedulefree/__init__.py b/helpers/training/optimizers/adamw_schedulefree/__init__.py index eedfbabc..55de9a57 100644 --- a/helpers/training/optimizers/adamw_schedulefree/__init__.py +++ b/helpers/training/optimizers/adamw_schedulefree/__init__.py @@ -2,6 +2,7 @@ from torch.optim.optimizer import Optimizer import math from typing import Iterable +from helpers.training.state_tracker import StateTracker class AdamWScheduleFreeKahan(Optimizer): @@ -140,5 +141,9 @@ def step(self, closure=None): self.k += 1 self.last_lr = adjusted_lr + StateTracker.set_last_lr(adjusted_lr) return loss + + def get_last_lr(self): + return self.last_lr diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index d637700e..c636a358 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -49,6 +49,9 @@ class StateTracker: # Aspect to resolution map, we'll store once generated for consistency. aspect_resolution_map = {} + # for schedulefree + last_lr = 0.0 + # hugging face hub user details hf_user = None @@ -530,3 +533,11 @@ def load_aspect_resolution_map(cls, dataloader_resolution: float): logger.debug( f"Aspect resolution map: {cls.aspect_resolution_map[dataloader_resolution]}" ) + + @classmethod + def get_last_lr(cls): + return cls.last_lr + + @classmethod + def set_last_lr(cls, last_lr: float): + cls.last_lr = float(last_lr) diff --git a/train.py b/train.py index 7101bf68..f876d1d0 100644 --- a/train.py +++ b/train.py @@ -841,13 +841,14 @@ def main(): f" {args.num_train_epochs} epochs and {num_update_steps_per_epoch} steps per epoch." ) overrode_max_train_steps = True - if not is_lr_scheduler_disabled(args.optimizer): + is_schedulefree = is_lr_scheduler_disabled(args.optimizer) + if is_schedulefree: logger.info( - f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps" + "Using experimental AdamW ScheduleFree optimiser from Facebook. Experimental due to newly added Kahan summation." ) else: logger.info( - "Using experimental AdamW ScheduleFree optimiser from Facebook. Experimental due to newly added Kahan summation." + f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps" ) if args.layer_freeze_strategy == "bitfit": from helpers.training.model_freeze import apply_bitfit_freezing @@ -988,7 +989,7 @@ def main(): if is_lr_scheduler_disabled(args.optimizer): # we don't use LR schedulers with schedulefree schedulers (lol) lr_scheduler = None - if not use_deepspeed_scheduler: + if not use_deepspeed_scheduler and not is_schedulefree: logger.info( f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps" ) @@ -996,12 +997,15 @@ def main(): args, optimizer, accelerator, logger, use_deepspeed_scheduler=False ) else: - logger.info(f"Using DeepSpeed learning rate scheduler") - lr_scheduler = accelerate.utils.DummyScheduler( - optimizer, - total_num_steps=args.max_train_steps, - warmup_num_steps=args.lr_warmup_steps, - ) + logger.info(f"Using dummy learning rate scheduler") + if torch.backends.mps.is_available(): + lr_scheduler = None + else: + lr_scheduler = accelerate.utils.DummyScheduler( + optimizer, + total_num_steps=args.max_train_steps, + warmup_num_steps=args.lr_warmup_steps, + ) if lr_scheduler is not None: if hasattr(lr_scheduler, "num_update_steps_per_epoch"): lr_scheduler.num_update_steps_per_epoch = num_update_steps_per_epoch @@ -2068,11 +2072,12 @@ def main(): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: try: - if lr_scheduler is not None: + if is_schedulefree: + # hackjob method of retrieving LR from accelerated optims + lr = StateTracker.get_last_lr() + else: lr_scheduler.step(**scheduler_kwargs) lr = lr_scheduler.get_last_lr()[0] - elif hasattr(optimizer, "last_lr"): - lr = optimizer.last_lr except Exception as e: logger.error( f"Failed to get the last learning rate from the scheduler. Error: {e}" From 9c4736ba2a7f464c6fd817af14f9350142047af4 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 23 Aug 2024 17:59:11 -0600 Subject: [PATCH 18/18] clean-up repeating calls --- train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index f876d1d0..b4e64c3c 100644 --- a/train.py +++ b/train.py @@ -2190,10 +2190,10 @@ def main(): ) print("\n") # schedulefree optim needs the optimizer to be in eval mode to save the state (and then back to train after) - if is_lr_scheduler_disabled(args.optimizer): + if is_schedulefree: optimizer.eval() accelerator.save_state(save_path) - if is_lr_scheduler_disabled(args.optimizer): + if is_schedulefree: optimizer.train() for _, backend in StateTracker.get_data_backends().items(): if "sampler" in backend: @@ -2215,10 +2215,10 @@ def main(): "lr": lr, } progress_bar.set_postfix(**logs) - if is_lr_scheduler_disabled(args.optimizer): + if is_schedulefree: optimizer.eval() validation.run_validations(validation_type="intermediary", step=step) - if is_lr_scheduler_disabled(args.optimizer): + if is_schedulefree: optimizer.train() if ( args.push_to_hub @@ -2254,7 +2254,7 @@ def main(): # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: - if is_lr_scheduler_disabled(args.optimizer): + if is_schedulefree: optimizer.eval() validation_images = validation.run_validations( validation_type="final",