fix to call train/eval in schedulefree #1605
kohya-ss committed Sep 18, 2024
1 parent e745021 commit 1286e00
Showing 4 changed files with 33 additions and 1 deletion.
3 changes: 3 additions & 0 deletions README.md
@@ -11,6 +11,9 @@ The command to install PyTorch is as follows:

### Recent Updates

Sep 18, 2024 (update 1):
Fixed an issue where `train()`/`eval()` was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now.

Sep 18, 2024:

- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details.
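Note: a schedule-free optimizer keeps an averaged copy of the weights; roughly, `eval()` swaps that average in (the right weights for sampling and checkpointing) and `train()` swaps the raw training weights back, which is why the missing calls mattered. A minimal sketch of that contract, assuming the upstream `schedulefree` package (not part of this repository; the model and learning rate are placeholders):

```python
# Minimal sketch of the schedule-free train/eval contract.
# Assumes `pip install schedulefree`; model and lr are placeholders.
import torch
import schedulefree

model = torch.nn.Linear(4, 4)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

optimizer.train()                       # required before optimizer.step()
loss = model(torch.randn(2, 4)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

optimizer.eval()                        # swap in averaged weights for sampling/saving
# ... sample images or save a checkpoint here ...
optimizer.train()                       # swap training weights back before the next step
```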
10 changes: 10 additions & 0 deletions flux_train.py
@@ -347,8 +347,13 @@ def train(args):

logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")

        if train_util.is_schedulefree_optimizer(optimizers[0], args):
            raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
        optimizer_train_fn = lambda: None  # dummy function
        optimizer_eval_fn = lambda: None  # dummy function
    else:
        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)

# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
@@ -760,6 +765,7 @@ def optimizer_hook(parameter: torch.Tensor):
progress_bar.update(1)
global_step += 1

optimizer_eval_fn()
flux_train_utils.sample_images(
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
)
@@ -778,6 +784,7 @@ def optimizer_hook(parameter: torch.Tensor):
global_step,
accelerator.unwrap_model(flux),
)
optimizer_train_fn()

current_loss = loss.detach().item()  # this is the average, so batch size should not matter
if len(accelerator.trackers) > 0:
@@ -800,6 +807,7 @@ def optimizer_hook(parameter: torch.Tensor):

accelerator.wait_for_everyone()

optimizer_eval_fn()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
@@ -816,12 +824,14 @@ def optimizer_hook(parameter: torch.Tensor):
flux_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
)
optimizer_train_fn()

is_main_process = accelerator.is_main_process
# if is_main_process:
flux = accelerator.unwrap_model(flux)

accelerator.end_training()
optimizer_eval_fn()

if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
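In the hunks above, sampling and checkpoint saving are always bracketed by `optimizer_eval_fn()` / `optimizer_train_fn()`; for non-schedule-free optimizers both are no-op lambdas, so the call sites stay unconditional. A condensed, self-contained sketch of that ordering (the stub functions are placeholders, not this repository's APIs):

```python
# Ordering sketch only: stubs stand in for sample_images / save_* / the real
# training step, and the no-op lambdas mirror the non-schedule-free case.
def training_step():
    pass  # forward/backward/optimizer.step() in the real script

def sample_and_save():
    pass  # placeholder for sample_images(...) and save_*_on_epoch_end_or_stepwise(...)

optimizer_train_fn = lambda: None  # replaced by optimizer.train for schedule-free optimizers
optimizer_eval_fn = lambda: None   # replaced by optimizer.eval for schedule-free optimizers

for step in range(100):
    training_step()
    if (step + 1) % 25 == 0:
        optimizer_eval_fn()   # averaged weights while sampling/saving
        sample_and_save()
        optimizer_train_fn()  # back to training weights before the next step
```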
15 changes: 14 additions & 1 deletion library/train_util.py
@@ -13,6 +13,7 @@
import time
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
@@ -4715,8 +4716,20 @@ def __instancecheck__(self, instance):
return optimizer_name, optimizer_args, optimizer


def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]:
    if not is_schedulefree_optimizer(optimizer, args):
        # return dummy func
        return lambda: None, lambda: None

    # get train and eval functions from optimizer
    train_fn = optimizer.train
    eval_fn = optimizer.eval

    return train_fn, eval_fn


def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
    return args.optimizer_type.lower().endswith("schedulefree".lower())  # or args.optimizer_schedulefree_wrapper


def get_dummy_scheduler(optimizer: Optimizer) -> Any:
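A rough standalone check of the helper added above, assuming the usual repository layout so that `library.train_util` is importable, and constructing `args` by hand (the real scripts get it from the argument parser and `get_optimizer`):

```python
# Hypothetical standalone check; names other than the helper itself are
# illustrative. With a non-schedule-free optimizer both callables are no-ops.
import argparse
import torch
from library import train_util

args = argparse.Namespace(optimizer_type="AdamW")
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)

train_fn, eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
train_fn()  # no-op for a regular optimizer
eval_fn()   # no-op for a regular optimizer
# With e.g. optimizer_type="AdamWScheduleFree" and the matching optimizer,
# these would be the optimizer's own train()/eval() methods instead.
```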
6 changes: 6 additions & 0 deletions train_network.py
@@ -498,6 +498,7 @@ def train(self, args):
# accelerator.print(f"trainable_params: {k} = {v}")

optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)

# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
Expand Down Expand Up @@ -1199,6 +1200,7 @@ def remove_model(old_ckpt_name):
progress_bar.update(1)
global_step += 1

optimizer_eval_fn()
self.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
)
@@ -1217,6 +1219,7 @@ def remove_model(old_ckpt_name):
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
optimizer_train_fn()

current_loss = loss.detach().item()
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
@@ -1243,6 +1246,7 @@ def remove_model(old_ckpt_name):
accelerator.wait_for_everyone()

# save the model at the specified epoch interval
optimizer_eval_fn()
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
@@ -1258,6 +1262,7 @@ def remove_model(old_ckpt_name):
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
optimizer_train_fn()

# end of epoch

@@ -1268,6 +1273,7 @@ def remove_model(old_ckpt_name):
network = accelerator.unwrap_model(network)

accelerator.end_training()
optimizer_eval_fn()

if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
