
minimum learning rate should be allowed to set in lr schedulers #26209

Closed · annahung31 opened this issue Sep 18, 2023 · 12 comments · Fixed by #29341

Labels: Feature request (Request for a new feature)

@annahung31 commented Sep 18, 2023

Feature request

In the lr schedulers currently provided in optimization.py, the minimum learning rate is always 0.0.
We could add one more input parameter, e.g. "min_lr", to let the user define the minimum learning rate.

Take _get_linear_schedule_with_warmup_lr_lambda as an example:
Original:

def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

We can change it into:

def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, min_lr: float = 0.0):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(min_lr, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
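A quick way to sanity-check the proposed change (a standalone sketch; the warmup/step counts below are made up). Note that since `LambdaLR` multiplies this value into the optimizer's base lr, `min_lr` here really acts as a ratio of the initial lr rather than an absolute learning rate:

```python
from functools import partial

def _get_linear_schedule_with_warmup_lr_lambda(
    current_step: int, *, num_warmup_steps: int, num_training_steps: int, min_lr: float = 0.0
):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(min_lr, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

# hypothetical numbers: 10 warmup steps, 100 training steps, floor of 0.1
lr_lambda = partial(
    _get_linear_schedule_with_warmup_lr_lambda,
    num_warmup_steps=10, num_training_steps=100, min_lr=0.1,
)
print(lr_lambda(5))    # 0.5  (warmup: 5/10)
print(lr_lambda(55))   # 0.5  (decay: 45/90)
print(lr_lambda(100))  # 0.1  (clamped; would be 0.0 without min_lr)
```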

Motivation

Some papers describe this kind of lr scheduling. Take LIMA as an example:

    Without warmup steps, we set the initial learning rate to 1e−5 and linearly decaying to 1e−6 by the end of training.

To reproduce the experiment using their recipe, I need to rewrite the scheduler (and all the related functions like get_scheduler/create_scheduler) in the trainer, which makes the code really ugly.
So I think it would be good to have this kind of feature to make the Trainer more flexible.

Your contribution

I can submit a PR for this feature.

@ArthurZucker ArthurZucker added the Feature request Request for a new feature label Sep 18, 2023
@ArthurZucker (Collaborator)

Is the following not pretty much what you are looking for:

        min_lr_ratio (`float`, *optional*, defaults to 0):
            The final learning rate at the end of the linear decay will be `init_lr * min_lr_ratio`.

@annahung31 (Author)

Yes, that would work for me. This argument is in optimization_tf.py; can we have it in optimization.py as well?

@baichuanzhou

I ran into the same kind of problem today. I think adding that option is a good idea.

@ArthurZucker (Collaborator) commented Sep 26, 2023

Would one of you like to open a PR for this? 🤗

@annahung31 (Author)

Yeah, let me do that!

@Drzhivago264 commented Oct 5, 2023

Wait, I face the same problem, but simply changing 0.0 to min_lr_ratio will not work.
According to the papers, the lr must be reduced slowly over all of the training steps.
If we only change the ratio, the model will reduce the lr until it reaches the minimum learning rate after x steps, and then keep training with that lr until it finishes.

This is my custom trainer. I don't know how to replicate the behavior from the papers yet; please correct me if I am wrong or if I misunderstand your implementation.
import math
from functools import partial

from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import LambdaLR
from transformers import Trainer


class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def create_optimizer_and_scheduler(self, num_training_steps):
        self.optimizer = AdamW(self.model.parameters(),
                               lr=self.args.learning_rate,
                               weight_decay=self.args.weight_decay,
                               eps=self.args.adam_epsilon)

        def CUSTOM_get_cosine_schedule_with_warmup_lr_lambda(
            current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
        ):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
            # clamp the multiplier at 0.1 instead of 0.0
            return max(0.1, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

        def CUSTOM_get_cosine_schedule_with_warmup(
            optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
        ):
            lr_lambda = partial(
                CUSTOM_get_cosine_schedule_with_warmup_lr_lambda,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=num_cycles,
            )
            return LambdaLR(optimizer, lr_lambda, last_epoch)

        self.lr_scheduler = CUSTOM_get_cosine_schedule_with_warmup(
            self.optimizer, self.args.warmup_steps, num_training_steps)
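The flat tail is easy to see numerically. This is a standalone sketch of the clamped cosine lambda above (`clamped_cosine_lambda` is my name for it, and the step counts are made up); the multiplier hits the 0.1 floor around 80% of the way through training and stays there:

```python
import math

def clamped_cosine_lambda(current_step, *, num_warmup_steps, num_training_steps, num_cycles=0.5):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    # cosine decay, clamped at 0.1 instead of 0.0
    return max(0.1, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

values = [clamped_cosine_lambda(s, num_warmup_steps=0, num_training_steps=100) for s in range(101)]
first_at_floor = next(s for s, v in enumerate(values) if v == 0.1)
# the multiplier reaches the floor around 80% of the way through and
# then runs flat for the rest of training
print(first_at_floor, values[-1])
```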

@annahung31 (Author)

Yes, you are right.
After I encountered this issue, my strategy was to add one more argument called milestone (I'm still thinking about a better name); it represents the training step at which the minimum learning rate is reached, after which training continues with that lr until it finishes.
So my implementation for the linear scheduler looks like this:

def _get_linear_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float,
    milestone: int,
):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    if current_step <= milestone:
        return 1 - (current_step / milestone) * (1 - min_lr_ratio)
    else:
        return min_lr_ratio

def get_linear_schedule_with_warmup(
    optimizer: Any,
    num_warmup_steps: int,
    num_training_steps: int,
    milestone: Optional[int],
    min_lr_ratio: float = 0.0,
    last_epoch: int = -1,
):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to initial lr * min_lr_ratio, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    milestone = (
        milestone if milestone is not None else num_training_steps - 1
    )  # to make the last learning rate become 0.

    lr_lambda = partial(
        _get_linear_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        min_lr_ratio=min_lr_ratio,
        milestone=milestone,
    )

    return LambdaLR(optimizer, lr_lambda, last_epoch)

If we want the lr to be reduced slowly over all of the training steps, we can set milestone = num_training_steps - 1. This is the default, so the behavior aligns with the original code.
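As a sanity check of the milestone variant (a standalone sketch; `milestone_linear_lambda` and the step counts are mine): with the default milestone = num_training_steps - 1 the multiplier lands on min_lr_ratio at the last step, though with nonzero warmup it jumps from the warmup ramp straight onto the decay line and never quite reaches 1.0:

```python
from functools import partial

def milestone_linear_lambda(current_step, *, num_warmup_steps, num_training_steps,
                            min_lr_ratio, milestone):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    if current_step <= milestone:
        return 1 - (current_step / milestone) * (1 - min_lr_ratio)
    return min_lr_ratio

num_training_steps = 100
lam = partial(milestone_linear_lambda, num_warmup_steps=10,
              num_training_steps=num_training_steps, min_lr_ratio=0.1,
              milestone=num_training_steps - 1)
peak = max(lam(s) for s in range(num_training_steps + 1))
# the multiplier ends at ~min_lr_ratio, but with warmup it never
# actually reaches 1.0 (it jumps from warmup onto the decay line)
print(round(lam(99), 6), peak)
```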

As for the cosine scheduler, I found that I can change the parameter num_cycles to control the curve.
The default is num_cycles=0.5, so the lr reaches 0.0 at the end of the training process. If we change it to 0.25, the lr instead ends at half of the initial learning rate.
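That num_cycles claim can be checked directly by evaluating the post-warmup cosine term at the end of training (progress = 1.0); this sketch just re-derives the multiplier formula from the scheduler:

```python
import math

def cosine_multiplier(progress, num_cycles):
    # post-warmup decay multiplier used by the cosine schedule
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

# num_cycles=0.5: cos(pi) = -1 at the end, so the lr decays all the way to 0.0
print(cosine_multiplier(1.0, num_cycles=0.5))   # 0.0
# num_cycles=0.25: cos(pi/2) = 0 at the end, so the lr ends at half the base lr
print(cosine_multiplier(1.0, num_cycles=0.25))  # 0.5
```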

However, this parameter cannot be reached from the Trainer settings. For now, my workaround is to pass the parameter like this:

train_args = TrainingArguments(....)
train_args.num_cycles = 0.25

# in custom trainer:
class MyTrainer(Trainer):
    ....

    def create_trainer(self, num_training_steps):
        self.lr_scheduler = get_scheduler(
            ....
            num_cycles=self.args.num_cycles,
        )

I'm thinking about adding it to TrainingArguments directly.

Does this implementation make sense to you? Any suggestions are welcome.

@Drzhivago264

Your implementation is almost perfect, but the curve has a break at your milestone.
I came up with this for the linear scheduler. I think there must be a better way, but we have all the variables needed to calculate the desired lr smoothly.

[Screenshot from 2023-10-07 04-57-19]

@desperadoola

Just replace

    max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

with

    max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
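With the missing closing parenthesis restored, this rescaling decays smoothly from 1.0 right after warmup to min_lr_ratio at the final step. A standalone sketch (the function name and step counts are mine):

```python
def rescaled_linear_lambda(current_step, *, num_warmup_steps, num_training_steps, min_lr_ratio):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    decay = float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
    # rescale the [0, 1] decay into [min_lr_ratio, 1]
    return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * decay)

kwargs = dict(num_warmup_steps=10, num_training_steps=100, min_lr_ratio=0.1)
print(rescaled_linear_lambda(10, **kwargs))   # 1.0 at the end of warmup
print(rescaled_linear_lambda(55, **kwargs))   # ~0.55 halfway through the decay
print(rescaled_linear_lambda(100, **kwargs))  # 0.1 at the last step
```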

@npielawski

Any update?

@young-chao commented Feb 5, 2024 via email

@young-chao commented Mar 26, 2024 via email
