From a93bdffb4769cf5d08789b04125573a73ab0e950 Mon Sep 17 00:00:00 2001
From: Zhong Hui
Date: Tue, 14 Jun 2022 23:06:52 +0800
Subject: [PATCH] [Trainer] Support constant and cosine lr scheduler (#2511)

* support constant and cosine lr scheduler

* fix doc

* delete

* add doc
---
 docs/trainer.md                    |   2 +-
 paddlenlp/trainer/trainer_base.py  |  11 +-
 paddlenlp/trainer/trainer_utils.py | 164 ++++++++++++++++++++++++++++-
 paddlenlp/trainer/training_args.py |   5 +-
 4 files changed, 170 insertions(+), 12 deletions(-)

diff --git a/docs/trainer.md b/docs/trainer.md
index 6ee736daa3d0..47a047bc23d0 100644
--- a/docs/trainer.md
+++ b/docs/trainer.md
@@ -285,7 +285,7 @@ Trainer 是一个简单,但功能完整的 Paddle训练和评估模块,并
 --lr_scheduler_type
             要使用的学习率调度策略。 (`str`, 可选, 默认为 `"linear"`)
 
-            The scheduler type to use. (default: linear)
+            The scheduler type to use. (default: linear) 支持 linear, cosine, constant, constant_with_warmup.
 
 --warmup_ratio
             用于从 0 到 `learning_rate` 的线性warmup的总训练步骤的比例。(`float`,可选,默认为 0.0)
diff --git a/paddlenlp/trainer/trainer_base.py b/paddlenlp/trainer/trainer_base.py
index 35ee1c8d89c1..d9bba637752c 100644
--- a/paddlenlp/trainer/trainer_base.py
+++ b/paddlenlp/trainer/trainer_base.py
@@ -66,6 +66,7 @@
     OptimizerNames,
     PREFIX_CHECKPOINT_DIR,
     get_last_checkpoint,
+    get_scheduler,
 )
 from .trainer_callback import (
     CallbackHandler,
@@ -919,14 +920,8 @@ def create_scheduler(self, num_training_steps: int):
         Args:
             num_training_steps (int): The number of training steps to do.
         """
-
-        def get_scheduler(lr_scheduler_type, learning_rate, num_warmup_steps,
-                          num_training_steps):
-            # TODO @ZHUI support others
-            return LinearDecayWithWarmup(learning_rate, num_training_steps,
-                                         num_warmup_steps)
-
-        warmup = self.args.warmup_steps if self.args.warmup_steps > 0 else self.args.warmup_ratio
+        warmup = self.args.warmup_steps if self.args.warmup_steps > 0 else int(
+            self.args.warmup_ratio * num_training_steps)
 
         if self.lr_scheduler is None:
             self.lr_scheduler = get_scheduler(
diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py
index 9f95ebbf5f39..8f04a421b564 100644
--- a/paddlenlp/trainer/trainer_utils.py
+++ b/paddlenlp/trainer/trainer_utils.py
@@ -28,6 +28,7 @@
 from typing import Dict, NamedTuple, Optional, Tuple, Union
 
 import numpy as np
+from paddle.optimizer.lr import LambdaDecay
 
 __all__ = [
     "TrainOutput",
@@ -38,6 +39,7 @@
     "set_seed",
     "speed_metrics",
     "get_last_checkpoint",
+    "get_scheduler",
 ]
 
 
@@ -178,12 +180,170 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None):
 class SchedulerType(ExplicitEnum):
     LINEAR = "linear"
     COSINE = "cosine"
-    COSINE_WITH_RESTARTS = "cosine_with_restarts"
-    POLYNOMIAL = "polynomial"
     CONSTANT = "constant"
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
 
 
+def get_constant_schedule(learning_rate: float, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate, using the given initial learning rate.
+    Args:
+        learning_rate (`float`):
+            The initial learning rate. It is a python float number.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
+    """
+    return LambdaDecay(learning_rate, lambda _: 1, last_epoch=last_epoch)
+
+
+def get_constant_schedule_with_warmup(learning_rate: float,
+                                      num_warmup_steps: int,
+                                      last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+    increases linearly between 0 and the initial learning rate.
+    Args:
+        learning_rate (`float`):
+            The initial learning rate. It is a python float number.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1.0, num_warmup_steps))
+        return 1.0
+
+    return LambdaDecay(learning_rate, lr_lambda, last_epoch=last_epoch)
+
+
+def get_linear_schedule_with_warmup(learning_rate: float,
+                                    num_warmup_steps,
+                                    num_training_steps,
+                                    last_epoch=-1):
+    """
+    Create a schedule with a learning rate that decreases linearly from the initial learning rate to 0, after
+    a warmup period during which it increases linearly from 0 to the initial learning rate.
+    Args:
+        learning_rate (`float`):
+            The initial learning rate. It is a python float number.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        return max(
+            0.0,
+            float(num_training_steps - current_step) /
+            float(max(1, num_training_steps - num_warmup_steps)))
+
+    return LambdaDecay(learning_rate, lr_lambda, last_epoch)
+
+
+def get_cosine_schedule_with_warmup(learning_rate: float,
+                                    num_warmup_steps: int,
+                                    num_training_steps: int,
+                                    num_cycles: float = 0.5,
+                                    last_epoch: int = -1):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function from the
+    initial learning rate to 0, after a warmup period during which it increases linearly between 0 and the
+    initial learning rate.
+    Args:
+        learning_rate (`float`):
+            The initial learning rate. It is a python float number.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(
+            max(1, num_training_steps - num_warmup_steps))
+        return max(
+            0.0, 0.5 *
+            (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+    return LambdaDecay(learning_rate, lr_lambda, last_epoch)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+    SchedulerType.CONSTANT: get_constant_schedule,
+    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+}
+
+
+def get_scheduler(
+    name: Union[str, SchedulerType],
+    learning_rate: float,
+    num_warmup_steps: Optional[int] = None,
+    num_training_steps: Optional[int] = None,
+):
+    """
+    Unified API to get any scheduler from its name.
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        learning_rate (`float`):
+            The initial learning rate. It is a python float number.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int`, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+    """
+    name = SchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    if name == SchedulerType.CONSTANT:
+        return schedule_func(learning_rate)
+
+    # All other schedulers require `num_warmup_steps`
+    if num_warmup_steps is None:
+        raise ValueError(
+            f"{name} requires `num_warmup_steps`, please provide that argument."
+        )
+
+    if name == SchedulerType.CONSTANT_WITH_WARMUP:
+        return schedule_func(learning_rate, num_warmup_steps=num_warmup_steps)
+
+    # All other schedulers require `num_training_steps`
+    if num_training_steps is None:
+        raise ValueError(
+            f"{name} requires `num_training_steps`, please provide that argument."
+        )
+
+    return schedule_func(learning_rate,
+                         num_warmup_steps=num_warmup_steps,
+                         num_training_steps=num_training_steps)
+
+
 def _secs2timedelta(secs):
     """
     convert seconds to hh:mm:ss.msec, msecs rounded to 2 decimals
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 4ac9e9ae4a03..5f5d33f3a278 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -322,7 +322,10 @@ class TrainingArguments:
     )
     lr_scheduler_type: str = field(
        default="linear",
-        metadata={"help": "The scheduler type to use."},
+        metadata={
+            "help":
+            "The scheduler type to use. Supports linear, cosine, constant, constant_with_warmup."
+        },
     )
     warmup_ratio: float = field(
         default=0.0,