Commit
Merge pull request #215 from OpenAccess-AI-Collective/adamw-hyperparams-cfg

support adamw and grad norm hyperparams
winglian authored Jun 15, 2023
2 parents bfd4e3e + cd4fb19 commit 1b31d75
Showing 4 changed files with 70 additions and 0 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -422,6 +422,12 @@ log_sweep_max_lr:
optimizer:
# specify weight decay
weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_epsilon:
# Gradient clipping max norm
max_grad_norm:
# whether to use bettertransformers
flash_optimum:
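
For reference, a hedged example of what these options might look like once set, shown here as the Python mapping that axolotl's YAML config resolves to. The values are illustrative only and mirror the usual transformers defaults; they are not taken from this commit.

# Illustrative values only; in practice these keys are set in the axolotl YAML config.
example_cfg = {
    "optimizer": "adamw_bnb_8bit",
    "weight_decay": 0.0,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-8,
    "max_grad_norm": 1.0,
}
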
9 changes: 9 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -115,6 +115,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
# TODO search Path("./") for one
training_arguments_kwargs["deepspeed"] = "./ds_config.json"

if cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
if cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
if cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
if cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm

training_args = transformers.TrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
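
A minimal standalone sketch of what the new branch does, assuming only that transformers is installed; build_training_args and its output_dir default are illustrative names, not part of the commit. Because each value is guarded by a plain truthiness check, explicitly setting a hyperparameter to 0 or 0.0 is treated the same as leaving it unset, and the transformers default applies.

import transformers

def build_training_args(cfg: dict, output_dir: str = "./out") -> transformers.TrainingArguments:
    # Forward only the values the user actually set, so transformers' own defaults
    # (betas 0.9/0.999, epsilon 1e-8, max_grad_norm 1.0) still apply otherwise.
    kwargs = {}
    if cfg.get("adam_beta1"):
        kwargs["adam_beta1"] = cfg["adam_beta1"]
    if cfg.get("adam_beta2"):
        kwargs["adam_beta2"] = cfg["adam_beta2"]
    if cfg.get("adam_epsilon"):
        kwargs["adam_epsilon"] = cfg["adam_epsilon"]
    if cfg.get("max_grad_norm"):
        kwargs["max_grad_norm"] = cfg["max_grad_norm"]
    return transformers.TrainingArguments(output_dir=output_dir, **kwargs)

args = build_training_args({"adam_beta2": 0.95, "max_grad_norm": 1.0})
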
5 changes: 5 additions & 0 deletions src/axolotl/utils/validation.py
@@ -87,6 +87,11 @@ def validate_config(cfg):
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)

if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer
):
logging.warning("adamw hyperparameters found, but no adamw optimizer set")

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
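
The guard keys off adamw_*-prefixed config names (the trainer and README use the adam_* prefix). Below is a self-contained sketch of the check, assuming cfg behaves like the DictDefault used in the tests, where missing keys read as falsy; check_adamw_hyperparams is an illustrative name, not part of validate_config.

import logging

def check_adamw_hyperparams(cfg: dict) -> None:
    # Warn when AdamW-specific hyperparameters are set but the configured optimizer
    # is not an AdamW variant (e.g. "adafactor", or no optimizer at all).
    optimizer = cfg.get("optimizer")
    if any(cfg.get(k) for k in ("adamw_beta1", "adamw_beta2", "adamw_epsilon")) and (
        not optimizer or "adamw" not in optimizer
    ):
        logging.warning("adamw hyperparameters found, but no adamw optimizer set")

check_adamw_hyperparams({"optimizer": "adafactor", "adamw_beta1": 0.9})      # emits the warning
check_adamw_hyperparams({"optimizer": "adamw_bnb_8bit", "adamw_beta1": 0.9})  # silent
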
50 changes: 50 additions & 0 deletions tests/test_validation.py
@@ -263,3 +263,53 @@ def test_flash_optimum(self):

with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)

def test_adamw_hyperparams(self):
cfg = DictDefault(
{
"optimizer": None,
"adamw_epsilon": 0.0001,
}
)

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
in record.message
for record in self._caplog.records
)

cfg = DictDefault(
{
"optimizer": "adafactor",
"adamw_beta1": 0.0001,
}
)

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
in record.message
for record in self._caplog.records
)

cfg = DictDefault(
{
"optimizer": "adamw_bnb_8bit",
"adamw_beta1": 0.0001,
"adamw_beta2": 0.0001,
"adamw_epsilon": 0.0001,
}
)

validate_config(cfg)

cfg = DictDefault(
{
"optimizer": "adafactor",
}
)

validate_config(cfg)
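
Assuming pytest is available in the environment, these cases can be exercised in isolation with:

pytest tests/test_validation.py -k adamw_hyperparams
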
