Merge pull request #1006 from bghira/feature/soapy-shampoops
SOAP optimiser
bghira authored Sep 29, 2024
2 parents 2868297 + e534c45 commit 657a00c
Showing 8 changed files with 547 additions and 33 deletions.
4 changes: 2 additions & 2 deletions configure.py
@@ -5,10 +5,10 @@
 from helpers.training.optimizer_param import optimizer_choices

 bf16_only_optims = [
-    key for key, value in optimizer_choices.items() if value["precision"] == "bf16"
+    key for key, value in optimizer_choices.items() if value.get("precision", "any") == "bf16"
 ]
 any_precision_optims = [
-    key for key, value in optimizer_choices.items() if value["precision"] == "any"
+    key for key, value in optimizer_choices.items() if value.get("precision", "any") == "any"
 ]
 model_classes = {
     "full": [
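The change above makes the precision bucketing in configure.py tolerant of registry entries that omit a "precision" key, treating them as "any". A minimal, self-contained sketch of the behaviour; the "hypothetical_optim" entry is made up purely for illustration:

optimizer_choices = {
    "adamw_bf16": {"precision": "bf16"},
    "soap": {"precision": "any", "gradient_precision": "fp32"},
    "hypothetical_optim": {},  # no "precision" key declared
}

# value["precision"] would raise KeyError on the last entry; .get() buckets it as "any".
bf16_only_optims = [
    key for key, value in optimizer_choices.items()
    if value.get("precision", "any") == "bf16"
]
any_precision_optims = [
    key for key, value in optimizer_choices.items()
    if value.get("precision", "any") == "any"
]
print(bf16_only_optims)      # ['adamw_bf16']
print(any_precision_optims)  # ['soap', 'hypothetical_optim']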
26 changes: 7 additions & 19 deletions helpers/configuration/cmd_args.py
@@ -1981,12 +1981,6 @@ def parse_cmdline_args(input_args=None):
             f"When using --resolution_type=pixel, --target_downsample_size must be at least 512 pixels. You may have accidentally entered {args.target_downsample_size} megapixels, instead of pixels."
         )

-    if "int4" in args.base_model_precision and torch.cuda.is_available():
-        print_on_main_thread(
-            "WARNING: int4 precision is ONLY supported on A100 and H100 or newer devices. Waiting 10 seconds to continue.."
-        )
-        time.sleep(10)
-
     model_is_bf16 = (
         args.base_model_precision == "no_change"
         and (args.mixed_precision == "bf16" or torch.backends.mps.is_available())
@@ -2009,19 +2003,13 @@ def parse_cmdline_args(input_args=None):
             f"Model is not using bf16 precision, but the optimizer {chosen_optimizer} requires it."
         )
     if is_optimizer_grad_fp32(args.optimizer):
-        print(
-            "[WARNING] Using a low-precision optimizer that requires fp32 gradients. Training will run more slowly."
+        warning_log(
+            "Using an optimizer that requires fp32 gradients. Training will potentially run more slowly."
         )
         if args.gradient_precision != "fp32":
-            print(
-                f"[WARNING] Overriding gradient_precision to 'fp32' for {args.optimizer} optimizer."
-            )
             args.gradient_precision = "fp32"
     else:
         if args.gradient_precision == "fp32":
-            print(
-                f"[WARNING] Overriding gradient_precision to 'unmodified' for {args.optimizer} optimizer, as fp32 gradients are not required."
-            )
             args.gradient_precision = "unmodified"

     if torch.backends.mps.is_available():
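The gradient-precision override above is keyed off is_optimizer_grad_fp32(). Given the "gradient_precision": "fp32" field this PR adds to the SOAP registry entry, a plausible sketch of that helper is shown below; the real implementation lives in helpers/training/optimizer_param.py and may differ in detail.

from helpers.training.optimizer_param import optimizer_choices

def is_optimizer_grad_fp32(optimizer_name: str) -> bool:
    # An optimizer requires fp32 gradients when its registry entry says so.
    choice = optimizer_choices.get(optimizer_name, {})
    return choice.get("gradient_precision", None) == "fp32"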
@@ -2165,7 +2153,7 @@
         or args.flux_fast_schedule
     ):
         if not args.flux_fast_schedule:
-            logger.error("Schnell requires --flux_fast_schedule.")
+            error_log("Schnell requires --flux_fast_schedule.")
             sys.exit(1)
         flux_version = "schnell"
         model_max_seq_length = 256
@@ -2202,11 +2190,11 @@
         )

     if args.flux_guidance_mode == "mobius":
-        logger.warning(
+        warning_log(
             "Mobius training is only for the most elite. Pardon my English, but this is not for those who don't like to destroy something beautiful every now and then. If you feel perhaps this is not for you, please consider using a different guidance mode."
         )
         if args.flux_guidance_min < 1.0:
-            logger.warning(
+            warning_log(
                 "Flux minimum guidance value for Mobius training is 1.0. Updating value.."
             )
             args.flux_guidance_min = 1.0
@@ -2339,7 +2327,7 @@
             )
             args.use_dora = False
         else:
-            logger.warning(
+            warning_log(
                 "DoRA support is experimental and not very thoroughly tested."
             )
             args.lora_initialisation_style = "default"
@@ -2350,7 +2338,7 @@
         args.data_backend_config = os.path.join(
             StateTracker.get_config_path(), "multidatabackend.json"
         )
-        logger.warning(
+        warning_log(
             f"No data backend config provided. Using default config at {args.data_backend_config}."
         )

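Throughout this file the PR replaces bare print() and logger.warning()/logger.error() calls with warning_log()/error_log() helpers. Their definitions are not part of the hunks shown here; the following is only a minimal sketch of what such helpers could look like (the logger name is assumed, and the real helpers may additionally gate output to the main process):

import logging

logger = logging.getLogger("SimpleTuner")  # logger name assumed for illustration

def warning_log(message: str) -> None:
    logger.warning(message)

def error_log(message: str) -> None:
    logger.error(message)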
19 changes: 19 additions & 0 deletions helpers/training/default_settings/safety_check.py
@@ -60,3 +60,22 @@ def safety_check(args, accelerator):
         raise FileNotFoundError(
             f"User prompt library not found at {args.user_prompt_library}. Please check the path and try again."
         )
+
+
+    # optimizer memory limit check for SOAP w/ 24G
+    if accelerator.device.type == "cuda" and accelerator.is_main_process:
+        import subprocess
+        output = subprocess.check_output(
+            [
+                "nvidia-smi",
+                "--query-gpu=memory.total",
+                "--format=csv,noheader,nounits",
+            ]
+        ).split(b"\n")[get_rank()]
+        total_memory = int(output.decode().strip()) / 1024
+        from math import ceil
+        total_memory_gb = ceil(total_memory)
+        if total_memory_gb < 32 and args.optimizer == "soap":
+            logger.warning(
+                f"Your GPU has {total_memory_gb}GB of memory. The SOAP optimiser may require more than this."
+            )
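The new safety check shells out to nvidia-smi, converts the reported MiB to GiB, and warns when a GPU with less than 32GB is paired with the SOAP optimizer. The same heuristic can be expressed with torch.cuda alone; this is an alternative sketch, not part of the PR, and it assumes the process's device index maps directly to the GPU being queried:

import torch
from math import ceil

def warn_if_soap_may_oom(optimizer_name: str, device_index: int = 0) -> None:
    if optimizer_name != "soap" or not torch.cuda.is_available():
        return
    # total_memory is reported in bytes; round up to whole GiB like the nvidia-smi path does with MiB.
    total_bytes = torch.cuda.get_device_properties(device_index).total_memory
    total_memory_gb = ceil(total_bytes / (1024**3))
    if total_memory_gb < 32:
        print(
            f"Your GPU has {total_memory_gb}GB of memory. "
            "The SOAP optimiser may require more than this."
        )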
20 changes: 19 additions & 1 deletion helpers/training/optimizer_param.py
@@ -11,6 +11,7 @@
     is_optimi_available = False
 from helpers.training.optimizers.adamw_bfloat16 import AdamWBF16
 from helpers.training.optimizers.adamw_schedulefree import AdamWScheduleFreeKahan
+from helpers.training.optimizers.soap import SOAP

 try:
     from optimum.quanto import QTensor
@@ -235,8 +236,25 @@
         },
         "class": optimi.SGD,
     },
+    "soap": {
+        "precision": "any",
+        "gradient_precision": "fp32",
+        "default_settings": {
+            "betas": (0.95, 0.95),
+            "shampoo_beta": -1,
+            "eps": 1e-8,
+            "weight_decay": 0.01,
+            "precondition_frequency": 10,
+            "max_precond_dim": 10000,
+            "merge_dims": False,
+            "precondition_1d": False,
+            "normalize_grads": False,
+            "data_format": "channels_first",
+            "correct_bias": True,
+        },
+        "class": SOAP,
+    },
 }

 args_to_optimizer_mapping = {
     "use_adafactor_optimizer": "adafactor",
     "use_prodigy_optimizer": "prodigy",
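The new registry entry wires the vendored SOAP class into optimizer_choices with the defaults shown above. SOAP (ShampoO with Adam in the Preconditioner's eigenbasis) accumulates Shampoo-style Kronecker statistics of the gradients, periodically refreshes their eigenbasis, and runs an Adam-style update in that rotated space. The following is a heavily simplified, self-contained sketch of one step for a single 2-D parameter; it illustrates the idea only and is not the implementation added in helpers/training/optimizers/soap.py (bias correction, moment re-projection on basis refresh, weight decay, and the merge_dims/precondition_1d/max_precond_dim handling are omitted, and the shampoo_beta=-1 sentinel is replaced by a fixed value):

import torch

def soap_step_2d(param, grad, state, lr=1e-4, betas=(0.95, 0.95),
                 shampoo_beta=0.95, eps=1e-8, precondition_frequency=10):
    """One simplified SOAP-style step for a 2-D parameter tensor."""
    if not state:
        state["step"] = 0
        state["exp_avg"] = torch.zeros_like(grad)
        state["exp_avg_sq"] = torch.zeros_like(grad)
        # Kronecker-factored (Shampoo) statistics, one factor per tensor dimension.
        state["GG_left"] = torch.zeros(grad.shape[0], grad.shape[0], device=grad.device)
        state["GG_right"] = torch.zeros(grad.shape[1], grad.shape[1], device=grad.device)
        state["Q_left"] = torch.eye(grad.shape[0], device=grad.device)
        state["Q_right"] = torch.eye(grad.shape[1], device=grad.device)

    state["step"] += 1
    # Accumulate the Shampoo statistics GG^T and G^T G with an exponential moving average.
    state["GG_left"].lerp_(grad @ grad.T, 1 - shampoo_beta)
    state["GG_right"].lerp_(grad.T @ grad, 1 - shampoo_beta)

    # Periodically refresh the preconditioner's eigenbasis (the expensive part,
    # amortised over precondition_frequency steps).
    if state["step"] % precondition_frequency == 1:
        state["Q_left"] = torch.linalg.eigh(state["GG_left"]).eigenvectors
        state["Q_right"] = torch.linalg.eigh(state["GG_right"]).eigenvectors

    # Rotate the gradient into the eigenbasis and run an Adam-style update there.
    g_rot = state["Q_left"].T @ grad @ state["Q_right"]
    state["exp_avg"].lerp_(g_rot, 1 - betas[0])
    state["exp_avg_sq"].lerp_(g_rot.square(), 1 - betas[1])
    update_rot = state["exp_avg"] / (state["exp_avg_sq"].sqrt() + eps)

    # Rotate the update back to the original basis and apply it.
    with torch.no_grad():
        param -= lr * (state["Q_left"] @ update_rot @ state["Q_right"].T)

# Toy usage: a single 4x3 parameter and random gradients.
p = torch.randn(4, 3)
st = {}
for _ in range(20):
    soap_step_2d(p, torch.randn_like(p), st)

Given the registry structure above, the trainer presumably instantiates the optimizer by looking up choice = optimizer_choices["soap"] and constructing choice["class"] with choice["default_settings"] passed as keyword arguments, since the setting names mirror the SOAP constructor's parameters.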

