Refactored scheduler callbacks (epoch-based/step-based warmup) #568

Merged 44 commits on Jan 20, 2023
Commits
44d3d85
Refactored scheduler callbacks to have clear names & updated config f…
BloodAxe Dec 13, 2022
bcbbf49
Merge branch 'master' into feature/SG-525-step-based-warmup
BloodAxe Dec 13, 2022
25b2201
Merge branch 'master' into feature/SG-525-step-based-warmup
shaydeci Dec 13, 2022
d584c23
Merge branch 'master' into feature/SG-525-step-based-warmup
BloodAxe Jan 9, 2023
0513b72
New callbacks API
BloodAxe Jan 9, 2023
4427ac6
New callbacks API
BloodAxe Jan 9, 2023
f1b3ebf
New callbacks API
BloodAxe Jan 9, 2023
03bfd97
Merge branch 'master' into feature/SG-581
BloodAxe Jan 9, 2023
4a7f6e7
Merge branch 'master' into feature/SG-581
BloodAxe Jan 12, 2023
727bd88
Fix imports
BloodAxe Jan 12, 2023
f067964
Uncomment sanity check
BloodAxe Jan 12, 2023
23ed76b
Fix callbacks order
BloodAxe Jan 12, 2023
a4bf08d
Fix callbacks order
BloodAxe Jan 12, 2023
ee4e579
Fix callbacks order
BloodAxe Jan 12, 2023
1e55891
Fix wrong import
BloodAxe Jan 12, 2023
e5a19cf
Fix wrong import
BloodAxe Jan 12, 2023
41f5b12
Merge branch 'master' into feature/SG-581
BloodAxe Jan 12, 2023
88c1e55
Call update_context
BloodAxe Jan 12, 2023
1cbef94
Call update_context
BloodAxe Jan 12, 2023
2893e2c
Fix on_train_batch_start timing
BloodAxe Jan 12, 2023
7f2f997
Merge branch 'master' into feature/SG-525-step-based-warmup
BloodAxe Jan 12, 2023
5d178c5
Merge branch 'feature/SG-581' into feature/SG-525-step-based-warmup
BloodAxe Jan 12, 2023
16efeae
Batch-step warmup callback (WIP)
BloodAxe Jan 13, 2023
627f764
Added docs clarifying mapping of new events to old Phase enum values
BloodAxe Jan 13, 2023
6a2d46f
Instantiate context after _prep_for_test
BloodAxe Jan 13, 2023
173f65c
Merge branch 'master' into feature/SG-581
BloodAxe Jan 13, 2023
1664ae3
Merge branch 'feature/SG-581' into feature/SG-525-step-based-warmup
BloodAxe Jan 13, 2023
d671b49
Update callback implementation
BloodAxe Jan 16, 2023
98432fa
Merge branch 'master' into feature/SG-525-step-based-warmup
BloodAxe Jan 16, 2023
d60faf9
replace on_train_batch_gradient_step_start with on_train_batch_start …
BloodAxe Jan 16, 2023
495a35e
Merge branch 'master' into feature/SG-525-step-based-warmup
shaydeci Jan 16, 2023
f8ff62a
Merge branch 'master' into feature/SG-525-step-based-warmup
BloodAxe Jan 18, 2023
d67bbf0
Merge branch 'master' into feature/SG-525-step-based-warmup
BloodAxe Jan 18, 2023
dcf2209
Added check for lr_warmup_epochs = 0
BloodAxe Jan 18, 2023
7c881e3
Merge branch 'master' into feature/SG-525-step-based-warmup
BloodAxe Jan 19, 2023
a5b0216
Merge branch 'master' into feature/SG-525-step-based-warmup
BloodAxe Jan 19, 2023
fd669b7
Added test for warmup + cosine lr scheduler
BloodAxe Jan 19, 2023
1d42791
Merge branch 'master' into feature/SG-525-step-based-warmup
BloodAxe Jan 19, 2023
a099b93
Added missing documentation for warmup params
BloodAxe Jan 19, 2023
ec92a58
Merge remote-tracking branch 'origin/feature/SG-525-step-based-warmup…
BloodAxe Jan 19, 2023
0e8fb6c
Fix tests
BloodAxe Jan 19, 2023
913df8e
Fix tests
BloodAxe Jan 19, 2023
77c76d9
Merge branch 'master' into feature/SG-525-step-based-warmup
BloodAxe Jan 19, 2023
17809ca
Merge branch 'master' into feature/SG-525-step-based-warmup
shaydeci Jan 19, 2023
2 changes: 2 additions & 0 deletions src/super_gradients/common/object_names.py
@@ -132,6 +132,8 @@ class LRWarmups:
"""Static class to hold all the supported LR Warmup names"""

LINEAR_STEP = "linear_step"
LINEAR_EPOCH_STEP = "linear_epoch_step"
LINEAR_BATCH_STEP = "linear_batch_step"


class Samplers:
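For quick reference, a minimal sketch (not part of this diff) of the string values behind the new `LRWarmups` constants added above; the comments reflect how the rest of this PR uses them.

```python
# The string values behind the constants added above (per object_names.py in this PR).
from super_gradients.common.object_names import LRWarmups

print(LRWarmups.LINEAR_STEP)        # "linear_step" - kept for backward compatibility, deprecated later in this PR
print(LRWarmups.LINEAR_EPOCH_STEP)  # "linear_epoch_step" - warmup stepping once per epoch
print(LRWarmups.LINEAR_BATCH_STEP)  # "linear_batch_step" - warmup stepping once per batch
```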
@@ -4,11 +4,12 @@ ckpt_name: ckpt_latest.pth # The checkpoint (.pth file) filename in CKPT_ROOT_D
lr_mode: # Learning rate scheduling policy, one of ['step','poly','cosine','function']
lr_schedule_function: # Learning rate scheduling function to be used when `lr_mode` is 'function'.
lr_warmup_epochs: 0 # number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
lr_warmup_steps: 0 # number of warmup steps (Used when warmup_mode=linear_batch_step)
lr_cooldown_epochs: 0 # epochs to cooldown LR (i.e the last epoch from scheduling view point=max_epochs-cooldown)
warmup_initial_lr: # Initial lr for linear_step. When none is given, initial_lr/(warmup_epochs+1) will be used.
warmup_initial_lr: # Initial lr for linear_epoch_step/linear_batch_step. When none is given, initial_lr/(warmup_epochs+1) will be used.
step_lr_update_freq: # (float) update frequency in epoch units for computing lr_updates when lr_mode=`step`.
cosine_final_lr_ratio: 0.01 # final learning rate ratio (only relevant when `lr_mode`='cosine')
warmup_mode: linear_step # learning rate warmup scheme, currently only 'linear_step' is supported
warmup_mode: linear_epoch_step # learning rate warmup scheme, currently 'linear_epoch_step' and 'linear_batch_step' are supported

lr_updates:
_target_: super_gradients.training.utils.utils.empty_list # This is a workaround to instantiate a list using _target_. If we would instantiate as "lr_updates: []",
@@ -7,7 +7,7 @@ lr_mode: cosine
cosine_final_lr_ratio: 0
lr_warmup_epochs: 1
warmup_initial_lr: 0
warmup_mode: linear_step
warmup_mode: linear_epoch_step
ema: False
loss: cross_entropy
clip_grad_norm: 1
@@ -27,4 +27,4 @@ valid_metrics_list: # metrics for evaluation
metric_to_watch: Accuracy
greater_metric_to_watch_is_better: True
average_best_models: False
_convert_: all
_convert_: all
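To illustrate the new keys, here is a hedged Python sketch of an equivalent `training_params` dict; the warmup-related keys mirror the YAML above, while `model`, `train_loader` and `valid_loader` are placeholders and the commented-out `Trainer` call is only indicative.

```python
# Sketch only: a training_params dict using the new batch-step warmup keys.
training_params = {
    "max_epochs": 20,
    "initial_lr": 0.1,
    "lr_mode": "cosine",
    "cosine_final_lr_ratio": 0.01,
    "warmup_mode": "linear_batch_step",  # new per-batch warmup mode
    "warmup_initial_lr": 1e-6,           # LR at the very first optimizer step
    "lr_warmup_steps": 100,              # capped at len(train_loader) internally
    "lr_warmup_epochs": 0,               # not used by linear_batch_step
    "loss": "cross_entropy",
    "train_metrics_list": ["Accuracy"],
    "valid_metrics_list": ["Accuracy"],
    "metric_to_watch": "Accuracy",
    "greater_metric_to_watch_is_better": True,
}

# from super_gradients import Trainer
# trainer = Trainer(experiment_name="warmup_demo")
# trainer.train(model=model, training_params=training_params,
#               train_loader=train_loader, valid_loader=valid_loader)
```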
1 change: 1 addition & 0 deletions src/super_gradients/training/params.py
@@ -3,6 +3,7 @@

DEFAULT_TRAINING_PARAMS = {
"lr_warmup_epochs": 0,
"lr_warmup_steps": 0,
"lr_cooldown_epochs": 0,
"warmup_initial_lr": None,
"cosine_final_lr_ratio": 0.01,
37 changes: 29 additions & 8 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -701,9 +701,25 @@ def train(

Learning rate scheduling function to be used when `lr_mode` is 'function'.

- `warmup_mode`: Union[str, Type[LRCallbackBase], None]

If not None, defines how the learning rate will be increased during the warmup phase.
Currently, the 'linear_epoch_step' and 'linear_batch_step' modes are supported.

- `lr_warmup_epochs` : int (default=0)

Number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
Relevant for `warmup_mode=linear_epoch_step`.
When lr_warmup_epochs > 0, the learning rate will be increased linearly from `warmup_initial_lr` to `initial_lr`,
stepping once per epoch.

- `lr_warmup_steps` : int (default=0)

Number of steps for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
Relevant for `warmup_mode=linear_batch_step`.
When lr_warmup_steps > 0, the learning rate will be increased linearly from `warmup_initial_lr` to `initial_lr`
over min(lr_warmup_steps, len(train_loader)) steps, stepping once per batch.
The capping avoids the warmup interfering with epoch-based schedulers.

- `cosine_final_lr_ratio` : float (default=0.01)
Final learning rate ratio (only relevant when `lr_mode`='cosine'). The cosine starts from initial_lr and reaches
@@ -1087,14 +1103,19 @@ def forward(self, inputs, targets):
**self.training_params.to_dict(),
)
)
if self.training_params.lr_warmup_epochs > 0:
warmup_mode = self.training_params.warmup_mode
if isinstance(warmup_mode, str):
warmup_callback_cls = LR_WARMUP_CLS_DICT[warmup_mode]
elif isinstance(warmup_mode, type) and issubclass(warmup_mode, LRCallbackBase):
warmup_callback_cls = warmup_mode
else:
raise RuntimeError("warmup_mode has to be either a name of a mode (str) or a subclass of PhaseCallback")

warmup_mode = self.training_params.warmup_mode
warmup_callback_cls = None
if isinstance(warmup_mode, str):
warmup_callback_cls = LR_WARMUP_CLS_DICT[warmup_mode]
elif isinstance(warmup_mode, type) and issubclass(warmup_mode, LRCallbackBase):
warmup_callback_cls = warmup_mode
elif warmup_mode is None:
pass
else:
raise RuntimeError("warmup_mode has to be either a name of a mode (str) or a subclass of PhaseCallback")

if warmup_callback_cls is not None:
self.phase_callbacks.append(
warmup_callback_cls(
train_loader_len=len(self.train_loader),
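The docstring and resolution code above accept `warmup_mode` either as a registered string name or as an `LRCallbackBase` subclass, with `None` disabling warmup. A standalone sketch of that dispatch (stand-in classes, not the Trainer's actual code path):

```python
# Standalone sketch of the warmup_mode dispatch: a registered string resolves
# through LR_WARMUP_CLS_DICT, a custom LRCallbackBase subclass is used as-is,
# and None disables warmup. Stand-in classes, not SG internals.
from typing import Optional, Type, Union


class LRCallbackBase:  # stand-in for super_gradients' LRCallbackBase
    pass


class EpochStepWarmupLRCallback(LRCallbackBase):
    pass


LR_WARMUP_CLS_DICT = {"linear_epoch_step": EpochStepWarmupLRCallback}


def resolve_warmup_callback(warmup_mode: Union[str, Type[LRCallbackBase], None]) -> Optional[Type[LRCallbackBase]]:
    if warmup_mode is None:
        return None  # no warmup callback is registered
    if isinstance(warmup_mode, str):
        return LR_WARMUP_CLS_DICT[warmup_mode]
    if isinstance(warmup_mode, type) and issubclass(warmup_mode, LRCallbackBase):
        return warmup_mode
    raise RuntimeError("warmup_mode has to be either a name of a mode (str) or a subclass of PhaseCallback")


assert resolve_warmup_callback("linear_epoch_step") is EpochStepWarmupLRCallback
assert resolve_warmup_callback(EpochStepWarmupLRCallback) is EpochStepWarmupLRCallback
assert resolve_warmup_callback(None) is None
```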
6 changes: 4 additions & 2 deletions src/super_gradients/training/utils/callbacks/__init__.py
@@ -5,7 +5,8 @@
ModelConversionCheckCallback,
DeciLabUploadCallback,
LRCallbackBase,
WarmupLRCallback,
EpochStepWarmupLRCallback,
BatchStepLinearWarmupLRCallback,
StepLRCallback,
ExponentialLRCallback,
PolyLRCallback,
@@ -41,7 +42,8 @@
"ModelConversionCheckCallback",
"DeciLabUploadCallback",
"LRCallbackBase",
"WarmupLRCallback",
"EpochStepWarmupLRCallback",
"BatchStepLinearWarmupLRCallback",
"StepLRCallback",
"ExponentialLRCallback",
"PolyLRCallback",
18 changes: 13 additions & 5 deletions src/super_gradients/training/utils/callbacks/all_callbacks.py
@@ -1,4 +1,5 @@
from super_gradients.common.object_names import Callbacks, LRSchedulers, LRWarmups
from super_gradients.training.datasets.datasets_utils import DetectionMultiscalePrePredictionCallback
from super_gradients.training.utils.callbacks.callbacks import (
DeciLabUploadCallback,
LRCallbackBase,
@@ -11,12 +12,11 @@
CosineLRCallback,
ExponentialLRCallback,
FunctionLRCallback,
WarmupLRCallback,
EpochStepWarmupLRCallback,
BatchStepLinearWarmupLRCallback,
)

from super_gradients.training.utils.deprecated_utils import wrap_with_warning
from super_gradients.training.utils.early_stopping import EarlyStop
from super_gradients.training.datasets.datasets_utils import DetectionMultiscalePrePredictionCallback


CALLBACKS = {
Callbacks.DECI_LAB_UPLOAD: DeciLabUploadCallback,
@@ -39,4 +39,12 @@
}


LR_WARMUP_CLS_DICT = {LRWarmups.LINEAR_STEP: WarmupLRCallback}
LR_WARMUP_CLS_DICT = {
LRWarmups.LINEAR_STEP: wrap_with_warning(
EpochStepWarmupLRCallback,
message=f"Parameter {LRWarmups.LINEAR_STEP} has been made deprecated and will be removed in the next SG release. "
f"Please use `{LRWarmups.LINEAR_EPOCH_STEP}` instead.",
),
LRWarmups.LINEAR_EPOCH_STEP: EpochStepWarmupLRCallback,
LRWarmups.LINEAR_BATCH_STEP: BatchStepLinearWarmupLRCallback,
}
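Assuming the imports introduced in this PR, a short sketch of how the mapping above is meant to behave: the deprecated `linear_step` key resolves to a warning-emitting factory rather than the class itself.

```python
# Sketch (imports exist per this PR): the deprecated "linear_step" key maps to a
# wrap_with_warning factory, so constructing a callback through it logs the
# deprecation message but still builds an EpochStepWarmupLRCallback.
from super_gradients.common.object_names import LRWarmups
from super_gradients.training.utils.callbacks.all_callbacks import LR_WARMUP_CLS_DICT

new_style = LR_WARMUP_CLS_DICT[LRWarmups.LINEAR_EPOCH_STEP]  # the class itself
old_style = LR_WARMUP_CLS_DICT[LRWarmups.LINEAR_STEP]        # factory that warns when called

# Calling either with the usual Trainer-provided kwargs yields the same callback type;
# e.g. old_style(train_loader_len=..., net=..., training_params=..., **lr_kwargs)
# additionally logs the deprecation warning first.
```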
109 changes: 96 additions & 13 deletions src/super_gradients/training/utils/callbacks/callbacks.py
@@ -3,7 +3,7 @@
import os
import signal
import time
from typing import List
from typing import List, Union

import cv2
import numpy as np
@@ -13,7 +13,7 @@

from super_gradients.common.environment.env_variables import env_variables
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase
from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase, Callback
from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback
from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization

@@ -282,25 +282,96 @@ def update_lr(self, optimizer, epoch, batch_idx=None):
param_group["lr"] = self.lr


class WarmupLRCallback(LRCallbackBase):
class EpochStepWarmupLRCallback(LRCallbackBase):
"""
LR scheduling callback for linear step warmup.
LR climbs from warmup_initial_lr with even steps to initial lr. When warmup_initial_lr is None- LR climb starts from
LR scheduling callback for linear step warmup. This scheduler uses a whole epoch as a single step.
LR climbs in even steps from warmup_initial_lr to initial_lr. When warmup_initial_lr is None - LR climb starts from
initial_lr/(1+warmup_epochs).

"""

def __init__(self, **kwargs):
super(WarmupLRCallback, self).__init__(Phase.TRAIN_EPOCH_START, **kwargs)
super(EpochStepWarmupLRCallback, self).__init__(Phase.TRAIN_EPOCH_START, **kwargs)
self.warmup_initial_lr = self.training_params.warmup_initial_lr or self.initial_lr / (self.training_params.lr_warmup_epochs + 1)
self.warmup_step_size = (self.initial_lr - self.warmup_initial_lr) / self.training_params.lr_warmup_epochs
self.warmup_step_size = (
(self.initial_lr - self.warmup_initial_lr) / self.training_params.lr_warmup_epochs if self.training_params.lr_warmup_epochs > 0 else 0
)

def perform_scheduling(self, context):
self.lr = self.warmup_initial_lr + context.epoch * self.warmup_step_size
self.update_lr(context.optimizer, context.epoch, None)

def is_lr_scheduling_enabled(self, context):
return self.training_params.lr_warmup_epochs >= context.epoch
return self.training_params.lr_warmup_epochs > 0 and self.training_params.lr_warmup_epochs >= context.epoch
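A small standalone sketch (illustrative numbers only) of the per-epoch LR values the formulas in `EpochStepWarmupLRCallback` produce:

```python
# Illustrative numbers; mirrors warmup_initial_lr / warmup_step_size / perform_scheduling above.
initial_lr = 0.1
lr_warmup_epochs = 3
warmup_initial_lr = None  # default: fall back to initial_lr / (lr_warmup_epochs + 1)

start_lr = warmup_initial_lr or initial_lr / (lr_warmup_epochs + 1)
step_size = (initial_lr - start_lr) / lr_warmup_epochs if lr_warmup_epochs > 0 else 0

for epoch in range(lr_warmup_epochs + 1):  # warmup is active while lr_warmup_epochs >= epoch
    print(f"epoch {epoch}: lr = {start_lr + epoch * step_size:.4f}")
# epoch 0: lr = 0.0250 ... epoch 3: lr = 0.1000
```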


class BatchStepLinearWarmupLRCallback(Callback):
"""
LR scheduling callback for linear warmup on each batch step.
LR climbs linearly from warmup_initial_lr to initial_lr.
"""

def __init__(
self,
warmup_initial_lr: float,
initial_lr: float,
train_loader_len: int,
update_param_groups: bool,
lr_warmup_steps: int,
training_params,
net,
**kwargs,
):
"""

:param warmup_initial_lr: Starting learning rate
:param initial_lr: Target learning rate after warmup
:param train_loader_len: Length of train data loader
:param lr_warmup_steps: Number of warmup steps; capped at train_loader_len so the warmup does not overlap epoch-based LR scheduling.
:param kwargs:
"""

super(BatchStepLinearWarmupLRCallback, self).__init__()

if lr_warmup_steps > train_loader_len:
logger.warning(
f"Number of warmup steps ({lr_warmup_steps}) is greater than number of steps in epoch ({train_loader_len}). "
f"Warmup steps will be capped to number of steps in epoch to avoid interfering with any pre-epoch LR schedulers."
)

lr_warmup_steps = min(lr_warmup_steps, train_loader_len)
learning_rates = np.linspace(start=warmup_initial_lr, stop=initial_lr, num=lr_warmup_steps, endpoint=True)

self.lr = initial_lr
self.initial_lr = initial_lr
self.update_param_groups = update_param_groups
self.training_params = training_params
self.net = net
self.learning_rates = learning_rates
self.train_loader_len = train_loader_len
self.lr_warmup_steps = lr_warmup_steps

def on_train_batch_start(self, context: PhaseContext) -> None:
global_training_step = context.batch_idx + context.epoch * self.train_loader_len
if global_training_step < self.lr_warmup_steps:
self.lr = float(self.learning_rates[global_training_step])
self.update_lr(context.optimizer, context.epoch, context.batch_idx)

def update_lr(self, optimizer, epoch, batch_idx=None):
"""
Same as in LRCallbackBase
:param optimizer:
:param epoch:
:param batch_idx:
:return:
"""
if self.update_param_groups:
param_groups = self.net.module.update_param_groups(optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len)
optimizer.param_groups = param_groups
else:
# UPDATE THE OPTIMIZERS PARAMETER
for param_group in optimizer.param_groups:
param_group["lr"] = self.lr
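A standalone sketch (illustrative numbers only) of the per-batch LR ramp built in the constructor above, including the cap at `train_loader_len`:

```python
# Illustrative numbers; mirrors the constructor and on_train_batch_start above.
import numpy as np

warmup_initial_lr = 1e-6
initial_lr = 0.1
train_loader_len = 50
lr_warmup_steps = 100  # deliberately larger than one epoch

# Cap warmup to one epoch so it cannot overlap epoch-based LR scheduling.
lr_warmup_steps = min(lr_warmup_steps, train_loader_len)
learning_rates = np.linspace(start=warmup_initial_lr, stop=initial_lr, num=lr_warmup_steps, endpoint=True)

warmup_lrs = []
for epoch in range(2):
    for batch_idx in range(train_loader_len):
        global_step = batch_idx + epoch * train_loader_len
        if global_step < lr_warmup_steps:
            warmup_lrs.append(float(learning_rates[global_step]))
        # after lr_warmup_steps the main scheduler (e.g. cosine) takes over

print(len(warmup_lrs), warmup_lrs[0], warmup_lrs[-1])  # 50 warmup steps, from 1e-06 up to 0.1
```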


class StepLRCallback(LRCallbackBase):
@@ -388,17 +459,29 @@ def __init__(self, max_epochs, cosine_final_lr_ratio, **kwargs):
def perform_scheduling(self, context):
effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
effective_max_epochs = self.max_epochs - self.training_params.lr_warmup_epochs - self.training_params.lr_cooldown_epochs
current_iter = self.train_loader_len * effective_epoch + context.batch_idx
max_iter = self.train_loader_len * effective_max_epochs
lr = 0.5 * self.initial_lr * (1.0 + math.cos(current_iter / (max_iter + 1) * math.pi))
# the cosine starts from initial_lr and reaches initial_lr * cosine_final_lr_ratio in last epoch
self.lr = lr * (1 - self.cosine_final_lr_ratio) + (self.initial_lr * self.cosine_final_lr_ratio)
current_iter = max(0, self.train_loader_len * effective_epoch + context.batch_idx - self.training_params.lr_warmup_steps)
max_iter = self.train_loader_len * effective_max_epochs - self.training_params.lr_warmup_steps

lr = self.compute_learning_rate(current_iter, max_iter, self.initial_lr, self.cosine_final_lr_ratio)
self.lr = float(lr)
self.update_lr(context.optimizer, context.epoch, context.batch_idx)

def is_lr_scheduling_enabled(self, context):
# Account for per-step warmup
if self.training_params.lr_warmup_steps > 0:
current_step = self.train_loader_len * context.epoch + context.batch_idx
return current_step >= self.training_params.lr_warmup_steps

post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs

@classmethod
def compute_learning_rate(cls, step: Union[float, np.ndarray], total_steps: float, initial_lr: float, final_lr_ratio: float):
# the cosine starts from initial_lr and reaches initial_lr * cosine_final_lr_ratio in last epoch

lr = 0.5 * initial_lr * (1.0 + np.cos(step / (total_steps + 1) * math.pi))
return lr * (1 - final_lr_ratio) + (initial_lr * final_lr_ratio)
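A quick standalone check of the `compute_learning_rate` formula above, showing it runs from `initial_lr` down toward `initial_lr * cosine_final_lr_ratio`:

```python
# Evaluate the same cosine expression as the classmethod above at the schedule endpoints.
import math
import numpy as np

def compute_learning_rate(step, total_steps, initial_lr, final_lr_ratio):
    lr = 0.5 * initial_lr * (1.0 + np.cos(step / (total_steps + 1) * math.pi))
    return lr * (1 - final_lr_ratio) + (initial_lr * final_lr_ratio)

initial_lr, final_lr_ratio, total_steps = 0.1, 0.01, 10_000
print(compute_learning_rate(0, total_steps, initial_lr, final_lr_ratio))            # ~0.1
print(compute_learning_rate(total_steps, total_steps, initial_lr, final_lr_ratio))  # ~0.001
```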


class FunctionLRCallback(LRCallbackBase):
"""
32 changes: 32 additions & 0 deletions src/super_gradients/training/utils/deprecated_utils.py
@@ -0,0 +1,32 @@
from typing import Callable, Any

from super_gradients.common.abstractions.abstract_logger import get_logger

logger = get_logger(__name__)


def wrap_with_warning(cls: Callable, message: str) -> Any:
"""
Emits a warning when the target class or function is called.

>>> from super_gradients.training.utils.deprecated_utils import wrap_with_warning
>>> from super_gradients.training.utils.callbacks import EpochStepWarmupLRCallback, BatchStepLinearWarmupLRCallback
>>>
>>> LR_WARMUP_CLS_DICT = {
>>> "linear": wrap_with_warning(
>>> EpochStepWarmupLRCallback,
>>> message=f"Parameter `linear` has been made deprecated and will be removed in the next SG release. Please use `linear_epoch` instead",
>>> ),
>>> "linear_epoch": EpochStepWarmupLRCallback,
>>> }

:param cls: A class or function to wrap
:param message: A message to emit when this class is called
:return: A factory method that returns wrapped class
"""

def _inner_fn(*args, **kwargs):
logger.warning(message)
return cls(*args, **kwargs)

return _inner_fn
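A self-contained sketch of the same pattern with a toy class (not SG code), showing that the returned factory logs the message and then constructs the wrapped class normally:

```python
# Toy example of the wrap_with_warning pattern: the factory logs the message,
# then forwards all arguments to the real class.
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def wrap_with_warning(cls, message):
    def _inner_fn(*args, **kwargs):
        logger.warning(message)
        return cls(*args, **kwargs)

    return _inner_fn


class NewCallback:
    def __init__(self, initial_lr: float):
        self.initial_lr = initial_lr


OldCallback = wrap_with_warning(NewCallback, message="`OldCallback` is deprecated, use `NewCallback` instead.")

cb = OldCallback(initial_lr=0.1)    # logs the deprecation warning
assert isinstance(cb, NewCallback)  # but still returns a NewCallback instance
```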