diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 7d116d2679181d..8d56d82d77d972 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -168,6 +168,7 @@
     is_torch_npu_available,
     is_torch_xla_available,
     is_torch_xpu_available,
+    is_torchao_available,
     logging,
     strtobool,
 )
@@ -1451,7 +1452,23 @@ def optimizer_hook(param):
                     "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
                 }
             )
+        elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
+            if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse(
+                "0.4.0"
+            ):
+                raise ImportError(
+                    "You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers."
+                    "Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao"
+                )
+            if version.parse(importlib.metadata.version("torch")) < version.parse("2.3"):
+                raise ImportError(
+                    "You need to have `torch>=2.3` in order to use torch 4-bit optimizers. "
+                    "Install it with `pip install --upgrade torch`"
+                )
+            from torchao.prototype.low_bit_optim import AdamW4bit
+            optimizer_cls = AdamW4bit
+            optimizer_kwargs.update(adam_kwargs)
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
         return optimizer_cls, optimizer_kwargs
 
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index ec8f575b6c3ea3..a270754a26abe5 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
+    ADAMW_TORCH_4BIT = "adamw_torch_4bit"
     SGD = "sgd"
     ADAGRAD = "adagrad"
     ADAMW_BNB = "adamw_bnb_8bit"
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index ca133a277c41b5..3a525befce0cc0 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -99,6 +99,7 @@
     is_apex_available,
     is_bitsandbytes_available,
     is_safetensors_available,
+    is_torchao_available,
     is_torchdistx_available,
 )
 from transformers.utils.hp_naming import TrialShortNamer
@@ -4210,6 +4211,16 @@ def hp_name(trial):
                 dict(default_adam_kwargs, **default_anyprecision_kwargs),
             )
         )
+    if is_torchao_available():
+        import torchao
+
+        optim_test_params.append(
+            (
+                TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_4BIT, output_dir="None"),
+                torchao.prototype.low_bit_optim.AdamW4bit,
+                default_adam_kwargs,
+            )
+        )
 
 
 @require_torch
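
Usage note (not part of the diff): with this change, the 4-bit optimizer is selected through the existing `optim` training argument, using the `"adamw_torch_4bit"` value registered in `OptimizerNames`. Below is a minimal sketch of how a user would enable it, assuming `torchao>=0.4.0` and `torch>=2.3` are installed; the model checkpoint and `train_dataset` are placeholders, not part of the change.

    from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

    # "adamw_torch_4bit" is the OptimizerNames value added in training_args.py; the
    # optimizer-selection branch added in trainer.py then instantiates torchao's AdamW4bit.
    args = TrainingArguments(
        output_dir="out",
        optim="adamw_torch_4bit",
        learning_rate=2e-5,
        per_device_train_batch_size=8,
    )

    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")  # placeholder checkpoint
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # train_dataset is a placeholder
    trainer.train()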