Add cosine_with_min_lr_schedule_with_warmup_lr_rate scheduler in Trainer #31870

Open · wants to merge 4 commits into main
82 changes: 82 additions & 0 deletions src/transformers/optimization.py
@@ -387,6 +387,87 @@ def get_cosine_with_min_lr_schedule_with_warmup(
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float,
    min_lr_rate: float = 0.0,
    warmup_lr_rate: float = None,
):
    current_step = float(current_step)
    num_warmup_steps = float(num_warmup_steps)
    num_training_steps = float(num_training_steps)

    if current_step < num_warmup_steps:
        if warmup_lr_rate is None:
            return (current_step + 1.0) / max(1.0, num_warmup_steps)
        else:
            warmup_lr_rate = float(warmup_lr_rate)
            return warmup_lr_rate + (1.0 - warmup_lr_rate) * current_step / max(1.0, num_warmup_steps - 1.0)
    progress = (current_step - num_warmup_steps + 1.0) / max(1.0, num_training_steps - num_warmup_steps)
    factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
    factor = factor * (1 - min_lr_rate) + min_lr_rate
    return max(0.0, factor)


def get_cosine_with_min_lr_schedule_with_warmup_lr_rate(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    min_lr: float = None,
    min_lr_rate: float = None,
    warmup_lr_rate: float = None,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.

Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
min_lr (`float`, *optional*):
The minimum learning rate to reach after the cosine schedule.
min_lr_rate (`float`, *optional*):
The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
warmup_lr_rate (`float`, *optional*):
The minimum learning rate as a ratio of the start learning rate. If not set, `warmup_lr_rate` will be treated as float(1/num_warmup_steps).

Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""

    if min_lr is not None and min_lr_rate is not None:
        raise ValueError("Only one of min_lr or min_lr_rate should be set")
    elif min_lr is not None:
        min_lr_rate = min_lr / optimizer.defaults["lr"]
    elif min_lr_rate is None:
        raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")

    lr_lambda = partial(
        _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
        min_lr_rate=min_lr_rate,
        warmup_lr_rate=warmup_lr_rate,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def _get_wsd_scheduler_lambda(
    current_step: int,
    *,
@@ -464,6 +545,7 @@ def get_wsd_schedule(
    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
    SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
    SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
    SchedulerType.COSINE_WARMUP_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup_lr_rate,
    SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
}

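For reference, here is a minimal sketch of how the new schedule could be exercised directly, outside the `Trainer`, assuming the function is importable from `transformers.optimization` once this change lands; the toy linear model and the chosen rates are illustrative only.

```python
import torch

from transformers.optimization import get_cosine_with_min_lr_schedule_with_warmup_lr_rate

# Toy setup: a single linear layer and AdamW with a peak learning rate of 0.2.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.2)

# Warmup starts at 50% of the peak LR (warmup_lr_rate=0.5) and the cosine
# decay bottoms out at 5% of the peak LR (min_lr_rate=0.05).
scheduler = get_cosine_with_min_lr_schedule_with_warmup_lr_rate(
    optimizer,
    num_warmup_steps=2,
    num_training_steps=10,
    min_lr_rate=0.05,
    warmup_lr_rate=0.5,
)

for step in range(10):
    # Expected: 0.1 at step 0, 0.2 at step 1, then cosine decay down to 0.01.
    print(step, scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()
```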
1 change: 1 addition & 0 deletions src/transformers/trainer_utils.py
@@ -415,6 +415,7 @@ class SchedulerType(ExplicitEnum):
    INVERSE_SQRT = "inverse_sqrt"
    REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
    COSINE_WITH_MIN_LR = "cosine_with_min_lr"
    COSINE_WARMUP_WITH_MIN_LR = "cosine_warmup_with_min_lr"
    WARMUP_STABLE_DECAY = "warmup_stable_decay"


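As a quick sanity check on the wiring, the new string value should resolve through `SchedulerType` and `get_scheduler` like the existing `cosine_with_min_lr` entry. The sketch below assumes the extra arguments are forwarded via `scheduler_specific_kwargs`, as they are for that schedule; the optimizer and values are placeholders.

```python
import torch

from transformers.optimization import get_scheduler
from transformers.trainer_utils import SchedulerType

optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)

# The string used for lr_scheduler_type maps onto the new enum member ...
name = SchedulerType("cosine_warmup_with_min_lr")
assert name is SchedulerType.COSINE_WARMUP_WITH_MIN_LR

# ... which TYPE_TO_SCHEDULER_FUNCTION routes to the new scheduler factory.
scheduler = get_scheduler(
    name,
    optimizer=optimizer,
    num_warmup_steps=2,
    num_training_steps=10,
    scheduler_specific_kwargs={"min_lr_rate": 0.1, "warmup_lr_rate": 0.5},
)
```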
28 changes: 28 additions & 0 deletions tests/trainer/test_trainer.py
@@ -763,6 +763,34 @@ def test_cosine_with_min_lr_scheduler(self):
        trainer.lr_scheduler.step()
        self.assertEqual(trainer.lr_scheduler.get_last_lr()[0], 1e-5)

    def test_cosine_with_min_lr_schedule_with_warmup_lr_rate(self):
        train_dataset = RegressionDataset()
        model = RegressionModel()
        num_steps, num_warmup_steps = 10, 2
        extra_kwargs = {"min_lr": 1e-5}  # Non-default arguments
        args = TrainingArguments(
            "./regression",
            lr_scheduler_type="cosine_warmup_with_min_lr",
            lr_scheduler_kwargs=extra_kwargs,
            learning_rate=0.2,
            warmup_steps=num_warmup_steps,
            report_to="none",
        )
        trainer = Trainer(model, args, train_dataset=train_dataset)
        trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)

        # Check that the scheduler was created
        self.assertIsNotNone(trainer.lr_scheduler)

        # Collect the learning rate at each step, then check the warmup, peak, and final values
        step_lrs = []
        for _ in range(num_steps):
            step_lrs.append(trainer.optimizer.param_groups[0]["lr"])
            trainer.lr_scheduler.step()
        self.assertEqual(step_lrs[0], 0.1)
        self.assertEqual(step_lrs[1], 0.2)
        self.assertEqual(step_lrs[-1], 1e-05)

    def test_reduce_lr_on_plateau_args(self):
        # test passed arguments for a custom ReduceLROnPlateau scheduler
        train_dataset = RegressionDataset(length=64)
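The new test pins exact values through an absolute `min_lr`. For completeness, a hypothetical configuration using the rate-based arguments instead might look like this; the keyword names follow the new scheduler's signature and the values are illustrative only.

```python
from transformers import TrainingArguments

args = TrainingArguments(
    "./regression",
    lr_scheduler_type="cosine_warmup_with_min_lr",
    # Start warmup at half the peak LR and floor the cosine decay at 5% of it.
    lr_scheduler_kwargs={"min_lr_rate": 0.05, "warmup_lr_rate": 0.5},
    learning_rate=0.2,
    warmup_steps=2,
    report_to="none",
)
# trainer = Trainer(model, args, train_dataset=train_dataset)  # model/dataset as in the test above
```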