merge #858

Merged: 22 commits, Aug 24, 2024
Commits
abe647f - Update and rename csv.py to csv_.py (williamzhuk, Aug 5, 2024)
6f91502 - remove some parts that are non-universal (williamzhuk, Aug 5, 2024)
52b6554 - Update factory.py (williamzhuk, Aug 5, 2024)
c693c7d - Update factory.py (williamzhuk, Aug 5, 2024)
ff0739c - Update factory.py (williamzhuk, Aug 5, 2024)
614764d - Update helpers/data_backend/csv_.py (williamzhuk, Aug 19, 2024)
9b833d2 - backwards-compat breaking: set cache_file_suffix by default to the id… (Aug 23, 2024)
c464201 - parquet: set width and height column value by default to width and he… (Aug 23, 2024)
8bb8481 - global shame: hide the existence of gradient accumulation steps (Aug 23, 2024)
71b33ac - Merge branch 'release' of ssh://github.com/bghira/SimpleTuner (Aug 23, 2024)
b549f84 - Rename csv_.py to csv_url_list.py (bghira, Aug 23, 2024)
48426a4 - Update factory.py (bghira, Aug 23, 2024)
eb92763 - Merge branch 'patch-1' of https://github.com/williamzhuk/simpletuner … (Aug 23, 2024)
5e5df13 - Merge pull request #856 from bghira/fix/csv (bghira, Aug 23, 2024)
177190d - trainer: do not override NUM_EPOCHS and MAX_NUM_STEPS in config.env.e… (Aug 23, 2024)
7c1f047 - new optimiser, adamw_schedulefree (Aug 23, 2024)
0cc0440 - schedulefree integration (Aug 23, 2024)
d29c480 - fix train.sh check for step/epoch being zero (Aug 23, 2024)
e9758a1 - schedulefree: retrieve last_lr (Aug 23, 2024)
4603949 - schedulefree: retrieve LR using statetracker (Aug 23, 2024)
9c4736b - clean-up repeating calls (Aug 23, 2024)
fd78777 - Merge pull request #857 from bghira/feature/optim-schedulefree (bghira, Aug 23, 2024)
5 changes: 3 additions & 2 deletions config/config.env.example
@@ -81,9 +81,10 @@ export TRACKER_PROJECT_NAME="${MODEL_TYPE}-training"
export TRACKER_RUN_NAME="simpletuner-sdxl"

# Max number of steps OR epochs can be used. Not both.
export MAX_NUM_STEPS=30000
# You MUST uncomment and set one of these.
#export MAX_NUM_STEPS=30000
# Will likely overtrain, but that's fine.
export NUM_EPOCHS=0
#export NUM_EPOCHS=0

# A convenient prefix for all of your training paths.
# These may be absolute or relative paths. Here, we are using relative paths.
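With this change the example config no longer pre-selects a step count: exactly one of MAX_NUM_STEPS or NUM_EPOCHS must be uncommented and set to a non-zero value. A minimal sketch of that rule as a startup check (illustrative only; the env parsing below is an assumption, not the project's actual train.sh validation):

```python
import os

# Hedged sketch: enforce "steps OR epochs, not both" from config.env.
max_num_steps = int(os.environ.get("MAX_NUM_STEPS", "0") or "0")
num_epochs = int(os.environ.get("NUM_EPOCHS", "0") or "0")

if (max_num_steps > 0) == (num_epochs > 0):
    raise SystemExit(
        "Uncomment and set exactly one of MAX_NUM_STEPS or NUM_EPOCHS in config.env."
    )
```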
helpers/data_backend/csv_url_list.py (renamed from helpers/data_backend/csv.py)
@@ -1,11 +1,8 @@
import fnmatch
import io
from datetime import datetime
from urllib.request import url2pathname
import hashlib

import pandas as pd
import requests
from PIL import Image

from helpers.data_backend.base import BaseDataBackend
from helpers.image_manipulation.load import load_image
@@ -28,23 +25,21 @@ def url_to_filename(url: str) -> str:
return url.split("/")[-1]


def shorten_and_clean_filename(filename: str, no_op: bool):
if no_op:
return filename
filename = filename.replace("%20", "-").replace(" ", "-")
if len(filename) > 250:
filename = filename[:120] + "---" + filename[126:]
return filename
def str_hash(filename: str) -> str:
return str(hashlib.sha256(str(filename).encode()).hexdigest())


def html_to_file_loc(parent_directory: Path, url: str, shorten_filenames: bool) -> str:
def path_to_hashed_path(path: Path, hash_filenames: bool) -> Path:
path = Path(path).resolve()
if hash_filenames:
return path.parent.joinpath(str_hash(path.stem) + path.suffix)
return path


def html_to_file_loc(parent_directory: Path, url: str, hash_filenames: bool) -> str:
filename = url_to_filename(url)
cached_loc = str(
parent_directory.joinpath(
shorten_and_clean_filename(filename, no_op=shorten_filenames)
)
)
return cached_loc
cached_loc = path_to_hashed_path(parent_directory.joinpath(filename), hash_filenames)
return str(cached_loc.resolve())


class CSVDataBackend(BaseDataBackend):
@@ -54,29 +49,28 @@ def __init__(
id: str,
csv_file: Path,
compress_cache: bool = False,
image_url_col: str = "url",
caption_column: str = "caption",
url_column: str = "url",
caption_column: str = "caption",
image_cache_loc: Optional[str] = None,
shorten_filenames: bool = False,
hash_filenames: bool = True,
):
self.id = id
self.type = "csv"
self.compress_cache = compress_cache
self.shorten_filenames = shorten_filenames
self.hash_filenames = hash_filenames
self.csv_file = csv_file
self.accelerator = accelerator
self.image_url_col = image_url_col
self.df = pd.read_csv(csv_file, index_col=image_url_col)
self.url_column = url_column
self.df = pd.read_csv(csv_file, index_col=url_column)
self.df = self.df.groupby(level=0).last() # deduplicate by index (image loc)
self.caption_column = caption_column
self.url_column = url_column
self.image_cache_loc = (
Path(image_cache_loc) if image_cache_loc is not None else None
)

def read(self, location, as_byteIO: bool = False):
"""Read and return the content of the file."""
already_hashed = False
if isinstance(location, Path):
location = str(location.resolve())
if location.startswith("http"):
@@ -85,11 +79,12 @@ def read(self, location, as_byteIO: bool = False):
cached_loc = html_to_file_loc(
self.image_cache_loc,
location,
shorten_filenames=self.shorten_filenames,
self.hash_filenames,
)
if os.path.exists(cached_loc):
# found cache
location = cached_loc
already_hashed = True
else:
# actually go to website
data = requests.get(location, stream=True).raw.data
@@ -99,8 +94,13 @@ def read(self, location, as_byteIO: bool = False):
data = requests.get(location, stream=True).raw.data
if not location.startswith("http"):
# read from local file
with open(location, "rb") as file:
data = file.read()
hashed_location = path_to_hashed_path(location, hash_filenames=self.hash_filenames and not already_hashed)
try:
with open(hashed_location, "rb") as file:
data = file.read()
except FileNotFoundError as e:
tqdm.write(f'ask was for file {location} bound to {hashed_location}')
raise e
if not as_byteIO:
return data
return BytesIO(data)
@@ -114,9 +114,7 @@ def write(self, filepath: Union[str, Path], data: Any) -> None:
filepath = Path(filepath)
# Not a huge fan of auto-shortening filenames, as we hash things for that in other cases.
# However, this is copied in from the original Arcade-AI CSV backend implementation for compatibility.
filepath = filepath.parent.joinpath(
shorten_and_clean_filename(filepath.name, no_op=self.shorten_filenames)
)
filepath = path_to_hashed_path(filepath, self.hash_filenames)
filepath.parent.mkdir(parents=True, exist_ok=True)
with open(filepath, "wb") as file:
# Check if data is a Tensor, and if so, save it appropriately
@@ -137,6 +135,7 @@ def delete(self, filepath):
if filepath in self.df.index:
self.df.drop(filepath, inplace=True)
# self.save_state()
filepath = path_to_hashed_path(filepath, self.hash_filenames)
if os.path.exists(filepath):
logger.debug(f"Deleting file: {filepath}")
os.remove(filepath)
@@ -146,13 +145,15 @@ def delete(self, filepath):

def exists(self, filepath):
"""Check if the file exists."""
if isinstance(filepath, Path):
filepath = str(filepath.resolve())
return filepath in self.df.index or os.path.exists(filepath)
if isinstance(filepath, str) and "http" in filepath:
return filepath in self.df.index
else:
filepath = path_to_hashed_path(filepath, self.hash_filenames)
return os.path.exists(filepath)

def open_file(self, filepath, mode):
"""Open the file in the specified mode."""
return open(filepath, mode)
return open(path_to_hashed_path(filepath, self.hash_filenames), mode)

def list_files(
self, file_extensions: list = None, instance_data_dir: str = None
@@ -291,9 +292,9 @@ def torch_save(self, data, location: Union[str, Path, BytesIO]):
"""
Save a torch tensor to a file.
"""

if isinstance(location, str) or isinstance(location, Path):
if location not in self.df.index:
self.df.loc[location] = pd.Series()
location = path_to_hashed_path(location, self.hash_filenames)
location = self.open_file(location, "wb")

if self.compress_cache:
@@ -309,7 +310,7 @@ def write_batch(self, filepaths: list, data_list: list) -> None:
self.write(filepath, data)

def save_state(self):
self.df.to_csv(self.csv_file, index_label=self.image_url_col)
self.df.to_csv(self.csv_file, index_label=self.url_column)

def get_caption(self, image_path: str) -> str:
if self.caption_column is None:
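The renamed CSV backend replaces filename shortening with sha256 hashing of cached filenames. A small, self-contained sketch of the helpers above and how a downloaded URL maps to its cache path (the cache directory and URL are hypothetical):

```python
import hashlib
from pathlib import Path


def str_hash(filename: str) -> str:
    # sha256 of the filename stem, as introduced in this diff
    return str(hashlib.sha256(str(filename).encode()).hexdigest())


def path_to_hashed_path(path: Path, hash_filenames: bool) -> Path:
    # hash only the stem, keep the original suffix
    path = Path(path).resolve()
    if hash_filenames:
        return path.parent.joinpath(str_hash(path.stem) + path.suffix)
    return path


cache_dir = Path("/tmp/csv-image-cache")            # hypothetical cache dir
url = "https://example.com/images/cat%20photo.png"  # hypothetical URL
filename = url.split("/")[-1]                       # mirrors url_to_filename()
print(path_to_hashed_path(cache_dir / filename, hash_filenames=True))
# -> /tmp/csv-image-cache/<64-char sha256 of "cat%20photo">.png
```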
20 changes: 14 additions & 6 deletions helpers/data_backend/factory.py
@@ -1,6 +1,6 @@
from helpers.data_backend.local import LocalDataBackend
from helpers.data_backend.aws import S3DataBackend
from helpers.data_backend.csv import CSVDataBackend
from helpers.data_backend.csv_url_list import CSVDataBackend
from helpers.data_backend.base import BaseDataBackend
from helpers.training.default_settings import default, latest_config_version
from helpers.caching.text_embeds import TextEmbeddingCache
@@ -100,6 +100,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
output["config"]["csv_file"] = backend["csv_file"]
if "csv_caption_column" in backend:
output["config"]["csv_caption_column"] = backend["csv_caption_column"]
if "csv_url_column" in backend:
output["config"]["csv_url_column"] = backend["csv_url_column"]
if "crop_aspect" in backend:
choices = ["square", "preserve", "random", "closest"]
if backend.get("crop_aspect", None) not in choices:
Expand Down Expand Up @@ -148,8 +150,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
)
if "hash_filenames" in backend:
output["config"]["hash_filenames"] = backend["hash_filenames"]
if "shorten_filenames" in backend and backend.get("type") == "csv":
output["config"]["shorten_filenames"] = backend["shorten_filenames"]
if "hash_filenames" in backend and backend.get("type") == "csv":
output["config"]["hash_filenames"] = backend["hash_filenames"]

# check if caption_strategy=parquet with metadata_backend=json
if (
@@ -635,7 +637,7 @@ def configure_multi_databackend(
csv_file=backend["csv_file"],
csv_cache_dir=backend["csv_cache_dir"],
compress_cache=args.compress_disk_cache,
shorten_filenames=backend.get("shorten_filenames", False),
hash_filenames=backend.get("hash_filenames", False),
)
# init_backend["instance_data_dir"] = backend.get("instance_data_dir", backend.get("instance_data_root", backend.get("csv_cache_dir")))
init_backend["instance_data_dir"] = None
@@ -721,7 +723,7 @@ def configure_multi_databackend(
delete_unwanted_images=backend.get(
"delete_unwanted_images", args.delete_unwanted_images
),
cache_file_suffix=backend.get("cache_file_suffix", None),
cache_file_suffix=backend.get("cache_file_suffix", init_backend["id"]),
repeats=init_backend["config"].get("repeats", 0),
**metadata_backend_args,
)
@@ -1092,8 +1094,10 @@ def get_csv_backend(
id: str,
csv_file: str,
csv_cache_dir: str,
url_column: str,
caption_column: str,
compress_cache: bool = False,
shorten_filenames: bool = False,
hash_filenames: bool = False,
) -> CSVDataBackend:
from pathlib import Path

@@ -1102,8 +1106,11 @@
id=id,
csv_file=Path(csv_file),
image_cache_loc=csv_cache_dir,
url_column=url_column,
caption_column=caption_column,
compress_cache=compress_cache,
shorten_filenames=shorten_filenames,
hash_filenames=hash_filenames,
)


@@ -1112,6 +1119,7 @@ def check_csv_config(backend: dict, args) -> None:
"csv_file": "This is the path to the CSV file containing your image URLs.",
"csv_cache_dir": "This is the path to your temporary cache files where images will be stored. This can grow quite large.",
"csv_caption_column": "This is the column in your csv which contains the caption(s) for the samples.",
"csv_url_column": "This is the column in your csv that contains image urls or paths.",
}
for key in required_keys.keys():
if key not in backend:
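For reference, a hedged sketch of a dataloader backend entry that exercises the new CSV options; the key names come from check_csv_config() and init_backend_config() above, while the id, paths, and column values are placeholders:

```python
# One entry of a multidatabackend configuration (placeholder values).
csv_backend = {
    "id": "my-csv-dataset",              # now also the default cache_file_suffix
    "type": "csv",
    "csv_file": "/data/images.csv",      # required: CSV of image URLs/paths
    "csv_cache_dir": "/data/csv-cache",  # required: where downloads are cached
    "csv_url_column": "url",             # required: column with URLs or paths
    "csv_caption_column": "caption",     # required: column with captions
    "hash_filenames": True,              # replaces the old shorten_filenames flag
}
```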
4 changes: 2 additions & 2 deletions helpers/metadata/backends/parquet.py
@@ -467,8 +467,8 @@ def _process_for_bucket(
statistics["skipped"]["metadata_missing"] += 1
return aspect_ratio_bucket_indices

width_column = self.parquet_config.get("width_column")
height_column = self.parquet_config.get("height_column")
width_column = self.parquet_config.get("width_column", "width")
height_column = self.parquet_config.get("height_column", "height")
if width_column is None or height_column is None:
raise ValueError(
"ParquetMetadataBackend requires width and height columns to be defined."
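The parquet metadata backend now assumes column names "width" and "height" when the config omits them; a short illustration of the new defaulting (the empty config is hypothetical):

```python
parquet_config = {}  # hypothetical config with no explicit column names
width_column = parquet_config.get("width_column", "width")
height_column = parquet_config.get("height_column", "height")
print(width_column, height_column)  # -> width height
```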
35 changes: 31 additions & 4 deletions helpers/training/optimizer_param.py
@@ -9,7 +9,8 @@
logger.setLevel(target_level)

is_optimi_available = False
from helpers.training.adam_bfloat16 import AdamWBF16
from helpers.training.optimizers.adamw_bfloat16 import AdamWBF16
from helpers.training.optimizers.adamw_schedulefree import AdamWScheduleFreeKahan

try:
from optimum.quanto import QTensor
@@ -35,6 +36,17 @@
},
"class": AdamWBF16,
},
"adamw_schedulefree": {
"precision": "any",
"override_lr_scheduler": True,
"can_warmup": True,
"default_settings": {
"betas": (0.9, 0.999),
"weight_decay": 1e-2,
"eps": 1e-8,
},
"class": AdamWScheduleFreeKahan,
},
"optimi-stableadamw": {
"precision": "any",
"default_settings": {
@@ -154,8 +166,8 @@
}

deprecated_optimizers = {
"prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use optimi-stableadamw or optimi-lion instead - for decoupled learning rate, use --optimizer_config=decoupled_lr=True.",
"dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use optimi-stableadamw instead.",
"prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"adafactor": "Adafactor optimiser has been removed in favour of optimi-stableadamw, which offers improved memory efficiency and convergence.",
"adamw8bit": "AdamW8Bit has been removed in favour of optimi-adamw optimiser, which offers better low-precision support. Please use this or adamw_bf16 instead.",
}
@@ -206,6 +218,16 @@ def optimizer_parameters(optimizer, args):
raise ValueError(f"Optimizer {optimizer} not found.")


def is_lr_scheduler_disabled(optimizer: str):
"""Check if the optimizer has a built-in LR scheduler"""
is_disabled = False
if optimizer in optimizer_choices:
is_disabled = optimizer_choices.get(optimizer).get(
"override_lr_scheduler", False
)
return is_disabled


def show_optimizer_defaults(optimizer: str = None):
"""we'll print the defaults on a single line, eg. foo=bar, buz=baz"""
if optimizer is None:
@@ -260,7 +282,12 @@ def determine_optimizer_class_with_config(
else:
optimizer_class, optimizer_details = optimizer_parameters(args.optimizer, args)
default_settings = optimizer_details.get("default_settings")
logger.info(f"cls: {optimizer_class}, settings: {default_settings}")
if optimizer_details.get("can_warmup", False):
logger.info(
f"Optimizer contains LR scheduler, warmup steps will be set to {args.lr_warmup_steps}."
)
default_settings["warmup_steps"] = args.lr_warmup_steps
logger.info(f"cls: {optimizer_class}, settings: {default_settings}")
return default_settings, optimizer_class


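The new adamw_schedulefree entry sets override_lr_scheduler and can_warmup, and is_lr_scheduler_disabled() exposes the former to the trainer. A hedged sketch of how those flags might be consumed; only the optimizer_choices entry and the helper mirror the diff, the surrounding trainer-side logic and values are illustrative:

```python
# Registry entry and helper as added in this diff (trimmed to what is needed).
optimizer_choices = {
    "adamw_schedulefree": {
        "precision": "any",
        "override_lr_scheduler": True,   # optimizer manages its own schedule
        "can_warmup": True,              # warmup steps are handed to the optimizer
        "default_settings": {"betas": (0.9, 0.999), "weight_decay": 1e-2, "eps": 1e-8},
    },
}


def is_lr_scheduler_disabled(optimizer: str) -> bool:
    """True when the chosen optimizer brings its own LR schedule."""
    return optimizer_choices.get(optimizer, {}).get("override_lr_scheduler", False)


# Illustrative consumption (hypothetical trainer-side values):
lr_warmup_steps = 500
details = optimizer_choices["adamw_schedulefree"]
settings = dict(details["default_settings"])
if details.get("can_warmup", False):
    settings["warmup_steps"] = lr_warmup_steps  # as determine_optimizer_class_with_config now does
if is_lr_scheduler_disabled("adamw_schedulefree"):
    lr_scheduler = None                         # skip building an external scheduler
```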