From 196d35ccfcb3cb07b18f8fd12208c74312c9ecfa Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Wed, 25 Sep 2024 13:07:21 -0400
Subject: [PATCH] Add AdEMAMix optimizer (#33682)

* Add AdEMAMix optimizer

* Fix test

* Update tests/trainer/test_trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
---
 src/transformers/trainer.py       |  31 ++
 src/transformers/training_args.py |   6 +-
 tests/trainer/test_trainer.py     | 165 ++++++++++++++++++++++++++++++
 3 files changed, 201 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e0a49ee5795e04..216d5cd4296008 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1237,6 +1237,10 @@ def get_optimizer_cls_and_kwargs(
             OptimizerNames.ADAMW_8BIT,
             OptimizerNames.PAGED_ADAMW,
             OptimizerNames.PAGED_ADAMW_8BIT,
+            OptimizerNames.ADEMAMIX,
+            OptimizerNames.ADEMAMIX_8BIT,
+            OptimizerNames.PAGED_ADEMAMIX,
+            OptimizerNames.PAGED_ADEMAMIX_8BIT,
             OptimizerNames.LION,
             OptimizerNames.LION_8BIT,
             OptimizerNames.PAGED_LION,
@@ -1266,6 +1270,33 @@
                     # Above we pass all `adam_kwargs` to the optimizer, here
                     # we only pass `optim_args` which can be passed by the user.
                     additional_optim_kwargs = optim_args
+                elif "ademamix" in args.optim:
+                    if is_bitsandbytes_available() and version.parse(
+                        importlib.metadata.version("bitsandbytes")
+                    ) < version.parse("0.44.0"):
+                        raise ValueError(
+                            "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. "
+                            "Please install `bitsandbytes` >= 0.44.0."
+                        )
+
+                    from bitsandbytes.optim import AdEMAMix
+
+                    optimizer_cls = AdEMAMix
+                    additional_optim_kwargs = {
+                        "betas": (
+                            float(optim_args.get("beta1", args.adam_beta1)),
+                            float(optim_args.get("beta2", args.adam_beta2)),
+                            float(optim_args.get("beta3", 0.9999)),
+                        ),
+                        "alpha": float(optim_args.get("alpha", 5.0)),
+                        "eps": float(optim_args.get("eps", args.adam_epsilon)),
+                    }
+
+                    if "t_alpha" in optim_args:
+                        additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"])
+
+                    if "t_beta3" in optim_args:
+                        additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"])
 
                 bnb_kwargs = {"optim_bits": optim_bits}
                 if "rmsprop" not in args.optim:
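The hunk above routes the four new optim values through the existing bitsandbytes branch and exposes the AdEMAMix-specific hyperparameters (beta3, alpha, t_alpha, t_beta3) through the existing `optim_args` string. As a usage illustration only, a minimal sketch that is not part of the patch; the concrete values are made up, and the defaults are beta3=0.9999 and alpha=5.0 with the two schedules unset:

    from transformers import TrainingArguments

    # Keys mirror the parsing above; the comma-separated key=value form is the same
    # syntax `optim_args` already accepts for other optimizers. Values are illustrative.
    args = TrainingArguments(
        output_dir="ademamix_demo",   # placeholder path
        optim="paged_ademamix_8bit",  # or "ademamix", "ademamix_8bit", "paged_ademamix_32bit"
        optim_args="beta3=0.999,alpha=8.0,t_alpha=10000,t_beta3=10000",
    )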
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 02413c28583256..596917928350bf 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -155,14 +155,18 @@ class OptimizerNames(ExplicitEnum):
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
     ADAMW_TORCH_4BIT = "adamw_torch_4bit"
+    ADEMAMIX = "ademamix"
     SGD = "sgd"
     ADAGRAD = "adagrad"
     ADAMW_BNB = "adamw_bnb_8bit"
     ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
+    ADEMAMIX_8BIT = "ademamix_8bit"
     LION_8BIT = "lion_8bit"
     LION = "lion_32bit"
     PAGED_ADAMW = "paged_adamw_32bit"
     PAGED_ADAMW_8BIT = "paged_adamw_8bit"
+    PAGED_ADEMAMIX = "paged_ademamix_32bit"
+    PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit"
     PAGED_LION = "paged_lion_32bit"
     PAGED_LION_8BIT = "paged_lion_8bit"
     RMSPROP = "rmsprop"
@@ -618,7 +622,7 @@ class TrainingArguments:
            "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
            for a full list of optimizers.
        optim_args (`str`, *optional*):
-           Optional arguments that are supplied to AnyPrecisionAdamW.
+           Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore.
        group_by_length (`bool`, *optional*, defaults to `False`):
            Whether or not to group together samples of roughly the same length in the training dataset (to minimize
            padding applied and be more efficient). Only useful if applying dynamic padding.
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 14014e4a0947cd..0035ff7de8ba97 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -15,6 +15,7 @@
 
 import dataclasses
 import gc
+import importlib
 import json
 import math
 import os
@@ -32,6 +33,7 @@
 
 import numpy as np
 from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
+from packaging import version
 from parameterized import parameterized
 from requests.exceptions import HTTPError
 
@@ -1091,6 +1093,40 @@ def test_rmsprop_bnb(self):
             # Check that it trains without errors
             trainer.train()
 
+    @require_bitsandbytes
+    def test_ademamix_bnb(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix"
+            )
+            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+            # Check that it trains without errors
+            trainer.train()
+
+    @require_bitsandbytes
+    def test_ademamix_bnb_8bit(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix_8bit"
+            )
+            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+            # Check that it trains without errors
+            trainer.train()
+
     @require_bitsandbytes
     def test_rmsprop_bnb_8bit(self):
         config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
@@ -4187,6 +4223,13 @@ def hp_name(trial):
         "lr": TrainingArguments.learning_rate,
     }
 
+    default_ademamix_kwargs = {
+        "betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2, 0.9999),
+        "alpha": 5.0,
+        "eps": TrainingArguments.adam_epsilon,
+        "lr": TrainingArguments.learning_rate,
+    }
+
     default_anyprecision_kwargs = {
         "use_kahan_summation": False,
         "momentum_dtype": torch.float32,
@@ -4291,6 +4334,36 @@ def hp_name(trial):
             )
         )
 
+        if version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.44.0"):
+            optim_test_params.append(
+                (
+                    TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
+                    bnb.optim.AdEMAMix,
+                    default_ademamix_kwargs,
+                )
+            )
+            optim_test_params.append(
+                (
+                    TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
+                    bnb.optim.AdEMAMix,
+                    default_ademamix_kwargs,
+                )
+            )
+            optim_test_params.append(
+                (
+                    TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
+                    bnb.optim.AdEMAMix,
+                    default_ademamix_kwargs,
+                )
+            )
+            optim_test_params.append(
+                (
+                    TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
+                    bnb.optim.AdEMAMix,
+                    default_ademamix_kwargs,
+                )
+            )
+
     if is_torchdistx_available():
         import torchdistx
 
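For reference, a minimal sketch of what the parameterized entries above check, assuming a local environment with `bitsandbytes>=0.44.0` installed (the output directory name is a placeholder):

    import bitsandbytes as bnb

    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(optim="ademamix", output_dir="tmp_ademamix")
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)

    assert optimizer_cls is bnb.optim.AdEMAMix
    # The returned kwargs include the defaults captured in `default_ademamix_kwargs` above:
    # betas=(adam_beta1, adam_beta2, 0.9999), alpha=5.0, eps=adam_epsilon, lr=learning_rate.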
@@ -4420,6 +4493,62 @@ def test_bnb_paged_adam8bit(self):
                 default_adam_kwargs,
             )
 
+    def test_bnb_ademamix(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
+                mock.optim.AdEMAMix,
+                default_ademamix_kwargs,
+            )
+
+    def test_bnb_ademamix8bit(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
+                mock.optim.AdEMAMix,
+                default_ademamix_kwargs,
+            )
+
+    def test_bnb_paged_ademamix(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
+                mock.optim.AdEMAMix,
+                default_ademamix_kwargs,
+            )
+
+    def test_bnb_paged_ademamix8bit(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
+                mock.optim.AdEMAMix,
+                default_ademamix_kwargs,
+            )
+
     def test_bnb_lion(self):
         mock = Mock()
         modules = {
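The four mocked tests above only exercise class and kwargs resolution by patching `sys.modules`; they never touch the real install. At runtime the trainer additionally enforces the version floor added in the trainer.py hunk, so a quick environment check before picking one of the ademamix values could look like the following sketch (the helper name is made up; only the comparison mirrors the patch):

    import importlib.metadata

    from packaging import version


    def has_ademamix_support() -> bool:
        # AdEMAMix ships with bitsandbytes 0.44.0+; older installs trigger the
        # ValueError raised in get_optimizer_cls_and_kwargs.
        try:
            bnb_version = importlib.metadata.version("bitsandbytes")
        except importlib.metadata.PackageNotFoundError:
            return False
        return version.parse(bnb_version) >= version.parse("0.44.0")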
@@ -4503,6 +4632,42 @@ def test_bnb_paged_adam8bit_no_bnb(self):
             with self.assertRaises(ValueError):
                 Trainer.get_optimizer_cls_and_kwargs(args)
 
+    def test_bnb_ademamix_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if `bitsandbytes` is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
+    def test_bnb_ademamix8bit_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if `bitsandbytes` is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
+    def test_bnb_paged_ademamix_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if `bitsandbytes` is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
+    def test_bnb_paged_ademamix8bit_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if `bitsandbytes` is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
     def test_bnb_paged_lion_no_bnb(self):
         args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")
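End to end, the new tests boil down to the following user-level pattern; this is a rough sketch with a toy model and dataset (all names and values are placeholders, and it assumes a CUDA machine with `bitsandbytes>=0.44.0` installed):

    import torch

    from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments

    # Tiny stand-ins for a real model and dataset, mirroring the shape used in the tests.
    model = GPT2LMHeadModel(GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4))
    train_dataset = [{"input_ids": row, "labels": row} for row in torch.randint(0, 100, (8, 128))]

    args = TrainingArguments(
        output_dir="ademamix_out",  # placeholder
        optim="ademamix_8bit",      # any of the four new values
        learning_rate=1e-4,
        max_steps=10,
    )
    Trainer(model=model, args=args, train_dataset=train_dataset).train()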