SOAP optimiser; int4 fixes for 4090 #1006

Merged: 9 commits, Sep 29, 2024
4 changes: 2 additions & 2 deletions configure.py
@@ -5,10 +5,10 @@
 from helpers.training.optimizer_param import optimizer_choices

 bf16_only_optims = [
-    key for key, value in optimizer_choices.items() if value["precision"] == "bf16"
+    key for key, value in optimizer_choices.items() if value.get("precision", "any") == "bf16"
 ]
 any_precision_optims = [
-    key for key, value in optimizer_choices.items() if value["precision"] == "any"
+    key for key, value in optimizer_choices.items() if value.get("precision", "any") == "any"
 ]
 model_classes = {
     "full": [
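Switching from a direct key lookup to dict.get with a default of "any" means configure.py no longer crashes on optimizer entries that omit an explicit "precision" key; such entries are simply grouped with the any-precision optimizers. A minimal sketch of that behaviour (the table below is an illustrative placeholder, not the project's real optimizer_choices):

```python
# Minimal sketch of the new .get() behaviour; the entries below are
# illustrative placeholders, not SimpleTuner's full optimizer table.
optimizer_choices = {
    "adamw_bf16": {"precision": "bf16"},
    "soap": {"precision": "any"},
    "legacy_entry": {},  # defines no "precision" key
}

# value["precision"] would raise KeyError for "legacy_entry";
# value.get("precision", "any") treats it as an any-precision optimizer instead.
bf16_only_optims = [
    key for key, value in optimizer_choices.items()
    if value.get("precision", "any") == "bf16"
]
any_precision_optims = [
    key for key, value in optimizer_choices.items()
    if value.get("precision", "any") == "any"
]

print(bf16_only_optims)      # ['adamw_bf16']
print(any_precision_optims)  # ['soap', 'legacy_entry']
```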
26 changes: 7 additions & 19 deletions helpers/configuration/cmd_args.py
@@ -1981,12 +1981,6 @@ def parse_cmdline_args(input_args=None):
             f"When using --resolution_type=pixel, --target_downsample_size must be at least 512 pixels. You may have accidentally entered {args.target_downsample_size} megapixels, instead of pixels."
         )

-    if "int4" in args.base_model_precision and torch.cuda.is_available():
-        print_on_main_thread(
-            "WARNING: int4 precision is ONLY supported on A100 and H100 or newer devices. Waiting 10 seconds to continue.."
-        )
-        time.sleep(10)
-
     model_is_bf16 = (
         args.base_model_precision == "no_change"
         and (args.mixed_precision == "bf16" or torch.backends.mps.is_available())
@@ -2009,19 +2003,13 @@ def parse_cmdline_args(input_args=None):
             f"Model is not using bf16 precision, but the optimizer {chosen_optimizer} requires it."
         )
    if is_optimizer_grad_fp32(args.optimizer):
-        print(
-            "[WARNING] Using a low-precision optimizer that requires fp32 gradients. Training will run more slowly."
+        warning_log(
+            "Using an optimizer that requires fp32 gradients. Training will potentially run more slowly."
        )
        if args.gradient_precision != "fp32":
-            print(
-                f"[WARNING] Overriding gradient_precision to 'fp32' for {args.optimizer} optimizer."
-            )
            args.gradient_precision = "fp32"
    else:
        if args.gradient_precision == "fp32":
-            print(
-                f"[WARNING] Overriding gradient_precision to 'unmodified' for {args.optimizer} optimizer, as fp32 gradients are not required."
-            )
            args.gradient_precision = "unmodified"

    if torch.backends.mps.is_available():
@@ -2165,7 +2153,7 @@
            or args.flux_fast_schedule
        ):
            if not args.flux_fast_schedule:
-                logger.error("Schnell requires --flux_fast_schedule.")
+                error_log("Schnell requires --flux_fast_schedule.")
                sys.exit(1)
            flux_version = "schnell"
            model_max_seq_length = 256
@@ -2202,11 +2190,11 @@
            )

        if args.flux_guidance_mode == "mobius":
-            logger.warning(
+            warning_log(
                "Mobius training is only for the most elite. Pardon my English, but this is not for those who don't like to destroy something beautiful every now and then. If you feel perhaps this is not for you, please consider using a different guidance mode."
            )
        if args.flux_guidance_min < 1.0:
-            logger.warning(
+            warning_log(
                "Flux minimum guidance value for Mobius training is 1.0. Updating value.."
            )
            args.flux_guidance_min = 1.0
@@ -2339,7 +2327,7 @@
            )
            args.use_dora = False
        else:
-            logger.warning(
+            warning_log(
                "DoRA support is experimental and not very thoroughly tested."
            )
        args.lora_initialisation_style = "default"
@@ -2350,7 +2338,7 @@
        args.data_backend_config = os.path.join(
            StateTracker.get_config_path(), "multidatabackend.json"
        )
-        logger.warning(
+        warning_log(
            f"No data backend config provided. Using default config at {args.data_backend_config}."
        )

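The surviving branch logic only forces fp32 gradients when the selected optimizer actually requires them, and otherwise drops an explicit fp32 request back to "unmodified". A standalone sketch of that decision, using a boolean stand-in for is_optimizer_grad_fp32 (the helper itself lives elsewhere in the project and is not part of this diff):

```python
# Sketch of the gradient-precision override as a pure function.
# `requires_fp32_grads` stands in for is_optimizer_grad_fp32(args.optimizer);
# the real lookup consults the table in helpers/training/optimizer_param.py.
def resolve_gradient_precision(requested: str, requires_fp32_grads: bool) -> str:
    if requires_fp32_grads:
        # Optimizers such as SOAP declare gradient_precision="fp32" and always
        # get fp32 gradients, whatever the user requested.
        return "fp32"
    # fp32 gradients are not required, so an explicit fp32 request is dropped
    # back to "unmodified"; any other value passes through untouched.
    return "unmodified" if requested == "fp32" else requested


assert resolve_gradient_precision("bf16", requires_fp32_grads=True) == "fp32"
assert resolve_gradient_precision("fp32", requires_fp32_grads=False) == "unmodified"
```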
19 changes: 19 additions & 0 deletions helpers/training/default_settings/safety_check.py
@@ -60,3 +60,22 @@ def safety_check(args, accelerator):
         raise FileNotFoundError(
             f"User prompt library not found at {args.user_prompt_library}. Please check the path and try again."
         )
+
+
+    # optimizer memory limit check for SOAP w/ 24G
+    if accelerator.device.type == "cuda" and accelerator.is_main_process:
+        import subprocess
+        output = subprocess.check_output(
+            [
+                "nvidia-smi",
+                "--query-gpu=memory.total",
+                "--format=csv,noheader,nounits",
+            ]
+        ).split(b"\n")[get_rank()]
+        total_memory = int(output.decode().strip()) / 1024
+        from math import ceil
+        total_memory_gb = ceil(total_memory)
+        if total_memory_gb < 32 and args.optimizer == "soap":
+            logger.warning(
+                f"Your GPU has {total_memory_gb}GB of memory. The SOAP optimiser may require more than this."
+            )
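nvidia-smi with --query-gpu=memory.total prints one MiB value per GPU, one per line; the check above picks this rank's line, converts it to GiB, and rounds up before comparing against 32 GB. A self-contained sketch of the same query outside the trainer (rank fixed at 0 here, and nvidia-smi assumed to be on PATH):

```python
# Standalone sketch of the VRAM check added in safety_check.py.
import subprocess
from math import ceil


def gpu_total_memory_gb(rank: int = 0) -> int:
    """Total memory of GPU `rank` in whole GiB, rounded up (nvidia-smi reports MiB)."""
    lines = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"]
    ).split(b"\n")
    return ceil(int(lines[rank].decode().strip()) / 1024)


if __name__ == "__main__":
    total_gb = gpu_total_memory_gb(0)
    if total_gb < 32:
        print(f"Your GPU has {total_gb}GB of memory; the SOAP optimiser may need more.")
```

On a 24 GB card such as the 4090, nvidia-smi reports roughly 24564 MiB, which rounds up to 24 GB and trips the warning when the SOAP optimiser is selected.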
20 changes: 19 additions & 1 deletion helpers/training/optimizer_param.py
@@ -11,6 +11,7 @@
 is_optimi_available = False
 from helpers.training.optimizers.adamw_bfloat16 import AdamWBF16
 from helpers.training.optimizers.adamw_schedulefree import AdamWScheduleFreeKahan
+from helpers.training.optimizers.soap import SOAP

 try:
     from optimum.quanto import QTensor
@@ -235,8 +236,25 @@
        },
        "class": optimi.SGD,
    },
+    "soap": {
+        "precision": "any",
+        "gradient_precision": "fp32",
+        "default_settings": {
+            "betas": (0.95, 0.95),
+            "shampoo_beta": -1,
+            "eps": 1e-8,
+            "weight_decay": 0.01,
+            "precondition_frequency": 10,
+            "max_precond_dim": 10000,
+            "merge_dims": False,
+            "precondition_1d": False,
+            "normalize_grads": False,
+            "data_format": "channels_first",
+            "correct_bias": True,
+        },
+        "class": SOAP,
+    },
 }

 args_to_optimizer_mapping = {
     "use_adafactor_optimizer": "adafactor",
     "use_prodigy_optimizer": "prodigy",
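For reference, an entry like the new "soap" one pairs a "class" with "default_settings" that are intended as constructor keyword arguments. The actual construction path in optimizer_param.py is not part of this diff, so the lookup below is only an illustrative sketch of how such an entry could be consumed:

```python
# Illustrative only: consuming a registry entry like the new "soap" one.
# Assumes the SimpleTuner repo is importable; the instantiation pattern is an
# assumption, since the real wiring in optimizer_param.py is not shown here.
import torch

from helpers.training.optimizers.soap import SOAP

soap_entry = {
    "default_settings": {
        "betas": (0.95, 0.95),
        "shampoo_beta": -1,
        "eps": 1e-8,
        "weight_decay": 0.01,
        "precondition_frequency": 10,
    },
    "class": SOAP,
}

model = torch.nn.Linear(128, 128)

# Registry defaults are splatted as keyword arguments; lr comes from the
# training configuration rather than from the registry itself.
optimizer = soap_entry["class"](
    model.parameters(), lr=1e-4, **soap_entry["default_settings"]
)
```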