diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py
index 4a92b18a30314b..ca316d19d3e17a 100644
--- a/src/transformers/optimization.py
+++ b/src/transformers/optimization.py
@@ -420,6 +420,12 @@ class Adafactor(Optimizer):
 
         Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
 
+    When using ``lr=None`` with :class:`~transformers.Trainer` you will most likely need to use the :class:`~transformers.optimization.AdafactorSchedule` scheduler as follows::
+
+        from transformers.optimization import Adafactor, AdafactorSchedule
+        optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+        lr_scheduler = AdafactorSchedule(optimizer)
+        trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
 
     Usage::
 
@@ -588,3 +594,52 @@ def step(self, closure=None):
                 p.data.copy_(p_data_fp32)
 
         return loss
+
+
+class AdafactorSchedule(LambdaLR):
+    """
+    Since :class:`~transformers.optimization.Adafactor` performs its own scheduling, if the training loop relies on a
+    scheduler (e.g., for logging), this class creates a proxy object that retrieves the current lr values from the
+    optimizer.
+
+    It returns ``initial_lr`` during startup and the actual ``lr`` during stepping.
+    """
+
+    def __init__(self, optimizer, initial_lr=0.0):
+        def lr_lambda(_):
+            return initial_lr
+
+        for group in optimizer.param_groups:
+            group["initial_lr"] = initial_lr
+        super().__init__(optimizer, lr_lambda)
+        for group in optimizer.param_groups:
+            del group["initial_lr"]
+
+    def get_lr(self):
+        opt = self.optimizer
+        lrs = [
+            opt._get_lr(group, opt.state[group["params"][0]])
+            for group in opt.param_groups
+            if group["params"][0].grad is not None
+        ]
+        if len(lrs) == 0:
+            lrs = self.base_lrs  # if called before stepping
+        return lrs
+
+
+def get_adafactor_schedule(optimizer, initial_lr=0.0):
+    """
+    Get a proxy schedule for :class:`~transformers.optimization.Adafactor`.
+
+    Args:
+        optimizer (:class:`~torch.optim.Optimizer`):
+            The optimizer for which to schedule the learning rate.
+        initial_lr (:obj:`float`, `optional`, defaults to 0.0):
+            Initial lr
+
+    Return:
+        :class:`~transformers.optimization.Adafactor` proxy schedule object.
+
+
+    """
+    return AdafactorSchedule(optimizer, initial_lr)
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 3610f98d819f9d..e5c2bf7b88bf3c 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -589,6 +589,25 @@ def test_custom_optimizer(self):
         self.assertFalse(torch.allclose(trainer.model.b, b))
         self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
 
+    @require_torch
+    def test_adafactor_lr_none(self):
+        # test the special case where lr=None, since Trainer always needs an lr_scheduler
+
+        from transformers.optimization import Adafactor, AdafactorSchedule
+
+        train_dataset = RegressionDataset()
+        args = TrainingArguments("./regression")
+        model = RegressionModel()
+        optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+        lr_scheduler = AdafactorSchedule(optimizer)
+        trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
+        trainer.train()
+
+        (a, b) = self.default_trained_model
+        self.assertFalse(torch.allclose(trainer.model.a, a))
+        self.assertFalse(torch.allclose(trainer.model.b, b))
+        self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
+
     def test_model_init(self):
         train_dataset = RegressionDataset()
         args = TrainingArguments("./regression", learning_rate=0.1)
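
Not part of the patch, but a minimal sketch of how the new proxy scheduler behaves outside of :class:`~transformers.Trainer`; the toy ``torch.nn.Linear`` model and the single hand-rolled training step are illustrative assumptions only::

    import torch
    from transformers.optimization import Adafactor, AdafactorSchedule

    # Toy model (assumed for illustration); any nn.Module with trainable parameters would do.
    model = torch.nn.Linear(4, 1)

    # lr=None lets Adafactor compute its own relative-step learning rate internally.
    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)

    # Before any step the grads are None, so the proxy falls back to base_lrs,
    # i.e. initial_lr (0.0 by default).
    print(lr_scheduler.get_lr())

    # After a backward pass and optimizer step, get_lr() reads the lr Adafactor actually used.
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    print(lr_scheduler.get_lr())

The same optimizer/scheduler pairing is what the new ``test_adafactor_lr_none`` test exercises end to end through :class:`~transformers.Trainer`.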