Skip to content

Commit

Permalink
FEAT / Trainer: Add adamw 4bit optimizer (huggingface#31865)
Browse files Browse the repository at this point in the history
* add 4bit optimizer

* style

* fix msg

* style

* add qgalore

* Revert "add qgalore"

This reverts commit 25278e8.

* style

* version check
  • Loading branch information
SunMarc authored and zucchini-nlp committed Aug 30, 2024
1 parent 4852055 commit 53cd61c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@
is_torch_npu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchao_available,
logging,
strtobool,
)
Expand Down Expand Up @@ -1451,7 +1452,23 @@ def optimizer_hook(param):
"gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
}
)
elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse(
"0.4.0"
):
raise ImportError(
"You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers."
"Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao"
)
if version.parse(importlib.metadata.version("torch")) < version.parse("2.3"):
raise ImportError(
"You need to have `torch>=2.3` in order to use torch 4-bit optimizers. "
"Install it with `pip install --upgrade torch`"
)
from torchao.prototype.low_bit_optim import AdamW4bit

optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
Expand Down
1 change: 1 addition & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
ADAMW_APEX_FUSED = "adamw_apex_fused"
ADAFACTOR = "adafactor"
ADAMW_ANYPRECISION = "adamw_anyprecision"
ADAMW_TORCH_4BIT = "adamw_torch_4bit"
SGD = "sgd"
ADAGRAD = "adagrad"
ADAMW_BNB = "adamw_bnb_8bit"
Expand Down
11 changes: 11 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
is_apex_available,
is_bitsandbytes_available,
is_safetensors_available,
is_torchao_available,
is_torchdistx_available,
)
from transformers.utils.hp_naming import TrialShortNamer
Expand Down Expand Up @@ -4210,6 +4211,16 @@ def hp_name(trial):
dict(default_adam_kwargs, **default_anyprecision_kwargs),
)
)
if is_torchao_available():
import torchao

optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_4BIT, output_dir="None"),
torchao.prototype.low_bit_optim.AdamW4bit,
default_adam_kwargs,
)
)


@require_torch
Expand Down

0 comments on commit 53cd61c

Please sign in to comment.