From 03657a0bb2cfd8e7cfe207fb7ad427477bdb78d2 Mon Sep 17 00:00:00 2001 From: Justin Zhao Date: Mon, 28 Aug 2023 13:48:19 -0400 Subject: [PATCH] Revert "DRAFT: Revert "Add Cosine Annealing LR scheduler as a decay method (#3507)" (#3545)" This reverts commit feec8a6f82abe37e4d49e52a0160a614e65b93bc. --- ludwig/modules/lr_scheduler.py | 71 ++++++++++++++--- ludwig/schema/lr_scheduler.py | 28 ++++++- ludwig/schema/metadata/configs/trainer.yaml | 14 +++- tests/ludwig/modules/test_lr_scheduler.py | 84 +++++++++++++++++++++ 4 files changed, 183 insertions(+), 14 deletions(-) diff --git a/ludwig/modules/lr_scheduler.py b/ludwig/modules/lr_scheduler.py index ef0f2d257e7..493b93e277a 100644 --- a/ludwig/modules/lr_scheduler.py +++ b/ludwig/modules/lr_scheduler.py @@ -1,9 +1,9 @@ import logging import math -from typing import Any, Dict +from typing import Any, Callable, Dict from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR, ReduceLROnPlateau, SequentialLR from ludwig.constants import MINIMIZE, TRAINING, VALIDATION from ludwig.modules.metric_registry import get_metric_objective @@ -166,14 +166,29 @@ def get_schedule_with_warmup( step_info: StepInfo, ) -> LambdaLR: """Creates a learning rate scheduler that updates each training step.""" - decay_fn = decay_registry[config.decay] + schedulers = [] - def lr_lambda(current_step: int): - if current_step < step_info.num_warmup_steps: - return float(current_step) / float(max(1, step_info.num_warmup_steps)) - return decay_fn(current_step, step_info.num_training_steps, step_info.num_warmup_steps, config) + # Warmup scheduler + if step_info.num_warmup_steps > 0: + warmup_scheduler = LambdaLR( + optimizer, + lambda current_step: float(current_step) / float(max(1, step_info.num_warmup_steps)), + last_epoch=-1, + ) + schedulers.append(warmup_scheduler) - return LambdaLR(optimizer, lr_lambda, last_epoch=-1) + # Decay scheduler + decay = config.decay + decay_scheduler = decay_registry[decay](config, optimizer, step_info) + schedulers.append(decay_scheduler) + + if len(schedulers) == 1: + # Only one scheduler, no need to wrap in a SequentialLR + return schedulers[0] + + # Return a SequentialLR that applies the warmup and decay schedulers in order + # with the warmup scheduler only applied for the first num_warmup_steps steps. + return SequentialLR(optimizer, schedulers=schedulers, milestones=[step_info.num_warmup_steps], last_epoch=-1) def no_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig): @@ -181,7 +196,11 @@ def no_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, def linear_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig): - return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) + return max( + 0.0, + float(num_training_steps - num_warmup_steps - current_step) + / float(max(1, num_training_steps - num_warmup_steps)), + ) def exponential_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig): @@ -194,8 +213,36 @@ def exponential_decay(current_step: int, num_training_steps: int, num_warmup_ste return math.pow(decay_rate, exponent) +def wrap_decay_fn(decay_fn: Callable) -> Callable: + def init_fn(config: LRSchedulerConfig, optimizer: Optimizer, step_info: StepInfo) -> LambdaLR: + return LambdaLR( + optimizer, + lambda current_step: decay_fn( + current_step, step_info.num_training_steps, step_info.num_warmup_steps, config + ), + last_epoch=-1, + ) + + return init_fn + + +def init_cosine_decay( + config: LRSchedulerConfig, + optimizer: Optimizer, + step_info: StepInfo, +) -> CosineAnnealingWarmRestarts: + return CosineAnnealingWarmRestarts( + optimizer, + T_0=config.t_0 or step_info.steps_per_checkpoint, + T_mult=config.t_mult or 1, + eta_min=config.eta_min or 0, + last_epoch=-1, + ) + + decay_registry = { - None: no_decay, - "linear": linear_decay, - "exponential": exponential_decay, + None: wrap_decay_fn(no_decay), + "linear": wrap_decay_fn(linear_decay), + "exponential": wrap_decay_fn(exponential_decay), + "cosine": init_cosine_decay, } diff --git a/ludwig/schema/lr_scheduler.py b/ludwig/schema/lr_scheduler.py index e102274c65d..3bfedab82bf 100644 --- a/ludwig/schema/lr_scheduler.py +++ b/ludwig/schema/lr_scheduler.py @@ -17,7 +17,7 @@ class LRSchedulerConfig(schema_utils.BaseMarshmallowConfig, ABC): """Configuration for learning rate scheduler parameters.""" decay: str = schema_utils.StringOptions( - options=["linear", "exponential"], + options=["linear", "exponential", "cosine"], default=None, allow_none=True, description="Turn on decay of the learning rate.", @@ -99,6 +99,32 @@ class LRSchedulerConfig(schema_utils.BaseMarshmallowConfig, ABC): parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["reduce_eval_split"], ) + # Parameters for CosineAnnealingWarmRestarts scheduler + + t_0: int = schema_utils.PositiveInteger( + default=None, + allow_none=True, + description="Number of steps before the first restart for cosine annealing decay. If not specified, it" + " will be set to `steps_per_checkpoint`.", + parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["t_0"], + ) + + t_mult: int = schema_utils.PositiveInteger( + default=1, + description="Period multiplier after each restart for cosine annealing decay. Defaults to 1, i.e.," + " restart every `t_0` steps. If set to a larger value, the period between restarts increases by that" + " multiplier. For e.g., if t_mult is 2, then the periods would be: t_0, 2*t_0, 2^2*t_0, 2^3*t_0, etc.", + parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["t_mult"], + ) + + eta_min: float = schema_utils.FloatRange( + default=0, + min=0, + max=1, + description="Minimum learning rate allowed for cosine annealing decay. Default: 0.", + parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["eta_min"], + ) + # TODO(travis): too much boilerplate here, we should find a way to abstract all this and only require specifying the # minimal amount needed for the new config object. diff --git a/ludwig/schema/metadata/configs/trainer.yaml b/ludwig/schema/metadata/configs/trainer.yaml index f97bf6b7905..59264676a09 100644 --- a/ludwig/schema/metadata/configs/trainer.yaml +++ b/ludwig/schema/metadata/configs/trainer.yaml @@ -520,7 +520,10 @@ ecd: suggested_values_reasoning: Starting with exponential decay is a safe place to start, as it is a "softer" decrease in the learning rate over time, as compared with linear, which is more steep after the initial drop. Linear decay is - most useful when the risk of catastrophic forgetting is very high (e.g, for fine-tuning pretrained models). + most useful when the risk of catastrophic forgetting is very high (e.g, for fine-tuning pretrained + models). Cosine annealing is a type of learning rate schedule that has the effect of starting with a + large learning rate that is relatively rapidly decreased to a minimum value before being increased + rapidly again. The resetting of the learning rate acts like a simulated restart of the learning process. If you observe your loss curves shooting up (even on the training set) in later epochs, increasing the decay rate may help mitigate this effect. ui_display_name: Decay @@ -600,6 +603,15 @@ ecd: reduce_eval_split: expected_impact: 1 ui_display_name: Reduce Eval Split + t_0: + expected_impact: 1 + ui_display_name: T_0 + t_mult: + expected_impact: 1 + ui_display_name: T_mult + eta_min: + expected_impact: 1 + ui_display_name: Eta Min gbm: learning_rate: commonly_used: true diff --git a/tests/ludwig/modules/test_lr_scheduler.py b/tests/ludwig/modules/test_lr_scheduler.py index e19e786e722..8ac19e606f4 100644 --- a/tests/ludwig/modules/test_lr_scheduler.py +++ b/tests/ludwig/modules/test_lr_scheduler.py @@ -1,3 +1,5 @@ +import math + import numpy as np from torch.optim import SGD @@ -33,6 +35,11 @@ def test_lr_scheduler_warmup_decay(): exp_scheduler = LRScheduler(config=exp_config, optimizer=exp_optimizer) exp_scheduler.reset(steps_per_checkpoint, total_steps) + cosine_optimizer = SGD(module.parameters(), lr=base_lr) + cosine_config = LRSchedulerConfig(warmup_fraction=warmup_fraction, decay="cosine", t_0=steps_per_checkpoint) + cosine_scheduler = LRScheduler(config=cosine_config, optimizer=cosine_optimizer) + cosine_scheduler.reset(steps_per_checkpoint, total_steps) + warmup_steps = total_steps * warmup_fraction for i in range(total_steps): # Offset by 1 @@ -48,17 +55,25 @@ def test_lr_scheduler_warmup_decay(): exp_scheduler.step() exp_lr = exp_optimizer.param_groups[0]["lr"] + cosine_scheduler.step() + cosine_lr = cosine_optimizer.param_groups[0]["lr"] + if step < warmup_steps: assert linear_lr == exp_lr, f"step: {step}" + assert linear_lr == cosine_lr, f"step: {step}" assert linear_lr < base_lr, f"step: {step}" elif step == warmup_steps: assert linear_lr == base_lr, f"step: {step}" + assert cosine_lr == base_lr, f"step: {step}" assert exp_lr < base_lr, f"step: {step}" else: assert linear_lr < base_lr, f"step: {step}" assert exp_lr < base_lr, f"step: {step}" + assert cosine_lr <= base_lr, f"step: {step}" assert linear_lr < exp_lr + assert exp_lr < cosine_lr + assert cosine_lr == base_lr def test_lr_scheduler_reduce_on_plateau(): @@ -119,6 +134,75 @@ def test_lr_scheduler_reduce_on_plateau(): assert np.isclose(lr, 0.001) +def test_lr_scheduler_cosine_decay_fixed_period(): + total_steps = 10000 + steps_per_checkpoint = 1000 + base_lr = 1.0 + + module = NumberInputFeature(NumberInputFeatureConfig(name="num1", encoder=DenseEncoderConfig())) + + optimizer = SGD(module.parameters(), lr=base_lr) + config = LRSchedulerConfig(decay="cosine", t_0=steps_per_checkpoint, decay_rate=0, reduce_on_plateau=0) + scheduler = LRScheduler(config=config, optimizer=optimizer) + scheduler.reset(steps_per_checkpoint, total_steps) + + curr_lr = base_lr + prev_lr = base_lr + num_restarts = 0 + for step in range(total_steps + 1): + # Cosine annealing formula + expected_lr = base_lr * 0.5 * (1 + math.cos(math.pi * (step % steps_per_checkpoint) / steps_per_checkpoint)) + assert np.isclose(curr_lr, expected_lr), f"step: {step}" + + if prev_lr < curr_lr: + # Since Cosine decay is periodic, we should see the learning rate + # decrease and then increase again. + num_restarts += 1 + + prev_lr = curr_lr + scheduler.step() + + curr_lr = optimizer.param_groups[0]["lr"] + + assert num_restarts == 10, f"num_restarts: {num_restarts}" + + +def test_lr_scheduler_cosine_decay_increasing_period(): + total_steps = 20000 + steps_per_checkpoint = 1000 + base_lr = 1.0 + + module = NumberInputFeature(NumberInputFeatureConfig(name="num1", encoder=DenseEncoderConfig())) + + optimizer = SGD(module.parameters(), lr=base_lr) + config = LRSchedulerConfig( + decay="cosine", + t_0=steps_per_checkpoint, + t_mult=2, + decay_rate=0, + reduce_on_plateau=0, + ) + scheduler = LRScheduler(config=config, optimizer=optimizer) + scheduler.reset(steps_per_checkpoint, total_steps) + + curr_lr = base_lr + prev_lr = base_lr + num_restarts = 0 + for _ in range(total_steps + 1): + if prev_lr < curr_lr: + # Since Cosine decay is periodic, we should see the learning rate + # decrease and then increase again. + num_restarts += 1 + + prev_lr = curr_lr + scheduler.step() + + curr_lr = optimizer.param_groups[0]["lr"] + + # 1000, 3000, 6000, 12000, 24000 (but we stop at 20000) + assert num_restarts == 4, f"num_restarts: {num_restarts}" + + def test_lr_scheduler_save_load(): steps_per_checkpoint = 10 total_steps = 100