Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for more Dadaptation #455

Merged
merged 14 commits into from
May 6, 2023
2 changes: 1 addition & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
Expand Down
93 changes: 90 additions & 3 deletions library/train_util.py
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_optimizer method looks like it would be better to refactor to make some parts common. If it is difficult to deal with, I will update it after the merge.

Original file line number Diff line number Diff line change
Expand Up @@ -1885,7 +1885,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
"--optimizer_type",
type=str,
default="",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, Lion8bit,SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptSGD, AdaFactor",
)

# backward compatibility
Expand Down Expand Up @@ -2467,7 +2467,7 @@ def task():


def get_optimizer(args, trainable_params):
# "Optimizer to use: AdamW, AdamW8bit, Lion, Lion8bit, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptSGD, Adafactor"

optimizer_type = args.optimizer_type
if args.use_8bit_adam:
Expand Down Expand Up @@ -2570,7 +2570,7 @@ def get_optimizer(args, trainable_params):
optimizer_class = torch.optim.SGD
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)

elif optimizer_type == "DAdaptation".lower():
elif optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdam".lower():
try:
import dadaptation
except ImportError:
Expand Down Expand Up @@ -2598,7 +2598,94 @@ def get_optimizer(args, trainable_params):

optimizer_class = dadaptation.DAdaptAdam
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type == "DAdaptAdaGrad".lower():
try:
import dadaptation
except ImportError:
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")

actual_lr = lr
lr_count = 1
if type(trainable_params) == list and type(trainable_params[0]) == dict:
lrs = set()
actual_lr = trainable_params[0].get("lr", actual_lr)
for group in trainable_params:
lrs.add(group.get("lr", actual_lr))
lr_count = len(lrs)

if actual_lr <= 0.1:
print(
f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
)
print("recommend option: lr=1.0 / 推奨は1.0です")
if lr_count > 1:
print(
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
)

optimizer_class = dadaptation.DAdaptAdaGrad
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type == "DAdaptAdan".lower():
try:
import dadaptation
except ImportError:
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")

actual_lr = lr
lr_count = 1
if type(trainable_params) == list and type(trainable_params[0]) == dict:
lrs = set()
actual_lr = trainable_params[0].get("lr", actual_lr)
for group in trainable_params:
lrs.add(group.get("lr", actual_lr))
lr_count = len(lrs)

if actual_lr <= 0.1:
print(
f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
)
print("recommend option: lr=1.0 / 推奨は1.0です")
if lr_count > 1:
print(
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
)

optimizer_class = dadaptation.DAdaptAdan
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type == "DAdaptSGD".lower():
try:
import dadaptation
except ImportError:
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")

actual_lr = lr
lr_count = 1
if type(trainable_params) == list and type(trainable_params[0]) == dict:
lrs = set()
actual_lr = trainable_params[0].get("lr", actual_lr)
for group in trainable_params:
lrs.add(group.get("lr", actual_lr))
lr_count = len(lrs)

if actual_lr <= 0.1:
print(
f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
)
print("recommend option: lr=1.0 / 推奨は1.0です")
if lr_count > 1:
print(
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
)

optimizer_class = dadaptation.DAdaptSGD
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type == "Adafactor".lower():
# 引数を確認して適宜補正する
if "relative_step" not in optimizer_kwargs:
Expand Down
5 changes: 4 additions & 1 deletion train_README-ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,10 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
- Lion8bit : 引数は同上
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
- SGDNesterov8bit : 引数は同上
- DAdaptation : https://github.com/facebookresearch/dadaptation
- DAdaptation(DAdaptAdam) : https://github.com/facebookresearch/dadaptation
- DAdaptAdaGrad : 引数は同上
- DAdaptAdan : 引数は同上
- DAdaptSGD : 引数は同上
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
- 任意のオプティマイザ

Expand Down
2 changes: 1 addition & 1 deletion train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def train(args):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
Expand Down
4 changes: 2 additions & 2 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
logs["lr/textencoder"] = float(lrs[0])
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder

if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): # tracking d*lr value of unet.
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
else:
idx = 0
Expand All @@ -53,7 +53,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche

for i in range(idx, len(lrs)):
logs[f"lr/group{i}"] = float(lrs[i])
if args.optimizer_type.lower() == "DAdaptation".lower():
if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():
logs[f"lr/d*lr/group{i}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
Expand Down
2 changes: 1 addition & 1 deletion train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def remove_model(old_ckpt_name):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
Expand Down
2 changes: 1 addition & 1 deletion train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def remove_model(old_ckpt_name):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
Expand Down