diff --git a/src/super_gradients/recipes/training_hyperparams/cityscapes_default_train_params.yaml b/src/super_gradients/recipes/training_hyperparams/cityscapes_default_train_params.yaml index 7b3b509e25..47f4baaa03 100644 --- a/src/super_gradients/recipes/training_hyperparams/cityscapes_default_train_params.yaml +++ b/src/super_gradients/recipes/training_hyperparams/cityscapes_default_train_params.yaml @@ -16,7 +16,7 @@ ema: True ema_params: decay: 0.9999 beta: 15 - exp_activation: True + decay_type: exp train_metrics_list: - PixelAccuracy: diff --git a/src/super_gradients/recipes/training_hyperparams/default_train_params.yaml b/src/super_gradients/recipes/training_hyperparams/default_train_params.yaml index 4fbe1b3467..58a59798aa 100644 --- a/src/super_gradients/recipes/training_hyperparams/default_train_params.yaml +++ b/src/super_gradients/recipes/training_hyperparams/default_train_params.yaml @@ -30,8 +30,8 @@ criterion_params: {} # when `loss` is one of SuperGradient's built in options, i ema: False # whether to use Model Exponential Moving Average ema_params: # parameters for the ema model. decay: 0.9999 + decay_type: exp beta: 15 - exp_activation: True train_metrics_list: [] # Metrics to log during training. For more information on torchmetrics see https://torchmetrics.rtfd.io/en/latest/. diff --git a/src/super_gradients/recipes/training_hyperparams/imagenet_efficientnet_train_params.yaml b/src/super_gradients/recipes/training_hyperparams/imagenet_efficientnet_train_params.yaml index 86ced3900d..3f8c1b122c 100644 --- a/src/super_gradients/recipes/training_hyperparams/imagenet_efficientnet_train_params.yaml +++ b/src/super_gradients/recipes/training_hyperparams/imagenet_efficientnet_train_params.yaml @@ -17,8 +17,8 @@ optimizer_params: ema: True ema_params: - exp_activation: False decay: 0.9999 + decay_type: constant loss: cross_entropy criterion_params: @@ -42,4 +42,3 @@ valid_metrics_list: # metrics for evaluation - Top5 _convert_: all - diff --git a/src/super_gradients/recipes/training_hyperparams/imagenet_regnetY_train_params.yaml b/src/super_gradients/recipes/training_hyperparams/imagenet_regnetY_train_params.yaml index 8776d99a2d..ad8d2f498c 100644 --- a/src/super_gradients/recipes/training_hyperparams/imagenet_regnetY_train_params.yaml +++ b/src/super_gradients/recipes/training_hyperparams/imagenet_regnetY_train_params.yaml @@ -17,7 +17,7 @@ optimizer_params: ema: True ema_params: - exp_activation: False + decay_type: constant decay: 0.9999 loss: cross_entropy diff --git a/src/super_gradients/training/kd_trainer/kd_trainer.py b/src/super_gradients/training/kd_trainer/kd_trainer.py index ab003fcac5..90efe91e38 100644 --- a/src/super_gradients/training/kd_trainer/kd_trainer.py +++ b/src/super_gradients/training/kd_trainer/kd_trainer.py @@ -1,21 +1,14 @@ +from typing import Union, Dict, Mapping, Any + import hydra import torch.nn from omegaconf import DictConfig, OmegaConf from torch.utils.data import DataLoader -from super_gradients.training.utils.distributed_training_utils import setup_device from super_gradients.common import MultiGPUMode -from super_gradients.training.dataloaders import dataloaders -from super_gradients.training.models import SgModule -from super_gradients.training.models.all_architectures import KD_ARCHITECTURES -from super_gradients.training.models.kd_modules.kd_module import KDModule -from super_gradients.training.sg_trainer import Trainer -from typing import Union, Dict from super_gradients.common.abstractions.abstract_logger import get_logger from 
super_gradients.training import utils as core_utils, models -from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES -from super_gradients.training.utils import get_param, HpmStruct -from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model +from super_gradients.training.dataloaders import dataloaders from super_gradients.training.exceptions.kd_trainer_exceptions import ( ArchitectureKwargsException, UnsupportedKDArchitectureException, @@ -24,7 +17,15 @@ TeacherKnowledgeException, UndefinedNumClassesException, ) +from super_gradients.training.models import SgModule +from super_gradients.training.models.all_architectures import KD_ARCHITECTURES +from super_gradients.training.models.kd_modules.kd_module import KDModule +from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES +from super_gradients.training.sg_trainer import Trainer +from super_gradients.training.utils import get_param, HpmStruct from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback +from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model +from super_gradients.training.utils.distributed_training_utils import setup_device from super_gradients.training.utils.ema import KDModelEMA logger = get_logger(__name__) @@ -255,17 +256,15 @@ def _get_hyper_param_config(self): ) return hyper_param_config - def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA: - """Instantiate KD ema model for KDModule. - - If the model is of class KDModule, the instance will be adapted to work on knowledge distillation. - :param decay: the maximum decay value. as the training process advances, the decay will climb towards - this value until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay) - :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will - saturate to its final value. beta=15 is ~40% of the training process. - :param exp_activation: + def _instantiate_ema_model(self, ema_params: Mapping[str, Any]) -> KDModelEMA: + """Instantiate KD EMA model for KDModule. + :param decay_type: (str) The decay climb schedule. See EMA_DECAY_FUNCTIONS for more details. + :param decay: The maximum decay value. As the training process advances, the decay will climb towards this value + according to decay_type schedule. See EMA_DECAY_FUNCTIONS for more details. + :param kwargs: Additional parameters for the decay function. See EMA_DECAY_FUNCTIONS for more details.
""" - return KDModelEMA(self.net, decay, beta, exp_activation) + logger.info(f"Using EMA with params {ema_params}") + return KDModelEMA.from_params(self.net, **ema_params) def _save_best_checkpoint(self, epoch, state): """ diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index 349f31693e..d30464a96c 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -1,41 +1,39 @@ import inspect import os from copy import deepcopy -from typing import Union, Tuple, Mapping, Dict from pathlib import Path +from typing import Union, Tuple, Mapping, Dict, Any +import hydra import numpy as np import torch -import hydra from omegaconf import DictConfig +from omegaconf import OmegaConf +from piptools.scripts.sync import _get_installed_distributions from torch import nn -from torch.utils.data import DataLoader, SequentialSampler from torch.cuda.amp import GradScaler, autocast +from torch.utils.data import DataLoader, SequentialSampler +from torch.utils.data.distributed import DistributedSampler from torchmetrics import MetricCollection from tqdm import tqdm -from piptools.scripts.sync import _get_installed_distributions - -from torch.utils.data.distributed import DistributedSampler -from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler - -from super_gradients.common.factories.callbacks_factory import CallbacksFactory +from super_gradients.common.abstractions.abstract_logger import get_logger from super_gradients.common.data_types.enum import MultiGPUMode, StrictLoad, EvaluationType -from super_gradients.training.models.all_architectures import ARCHITECTURES from super_gradients.common.decorators.factory_decorator import resolve_param -from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.common.environment.device_utils import device_config +from super_gradients.common.factories.callbacks_factory import CallbacksFactory from super_gradients.common.factories.list_factory import ListFactory from super_gradients.common.factories.losses_factory import LossesFactory from super_gradients.common.factories.metrics_factory import MetricsFactory +from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory from super_gradients.common.sg_loggers import SG_LOGGERS from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger from super_gradients.training import utils as core_utils, models, dataloaders -from super_gradients.training.models import SgModule -from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES -from super_gradients.training.utils import sg_trainer_utils, get_param -from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params +from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger +from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat +from super_gradients.training.metrics import Accuracy, Top5 from super_gradients.training.metrics.metric_utils import ( get_metrics_titles, get_metrics_results_tuple, @@ -43,7 +41,30 @@ get_metrics_dict, get_train_loop_description_dict, ) +from super_gradients.training.models import SgModule 
+from super_gradients.training.models.all_architectures import ARCHITECTURES from super_gradients.training.params import TrainingParams +from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES +from super_gradients.training.utils import HpmStruct +from super_gradients.training.utils import random_seed +from super_gradients.training.utils import sg_trainer_utils, get_param +from super_gradients.training.utils.callbacks import ( + CallbackHandler, + Phase, + LR_SCHEDULERS_CLS_DICT, + PhaseContext, + MetricsUpdateCallback, + LR_WARMUP_CLS_DICT, + ContextSgMethods, + LRCallbackBase, +) +from super_gradients.training.utils.checkpoint_utils import ( + get_ckpt_local_path, + read_ckpt_state_dict, + load_checkpoint_to_model, + load_pretrained_weights, + get_checkpoints_dir_path, +) from super_gradients.training.utils.distributed_training_utils import ( MultiGPUModeAutocastWrapper, reduce_results_tuple_for_ddp, @@ -59,34 +80,11 @@ DDPNotSetupException, ) from super_gradients.training.utils.ema import ModelEMA +from super_gradients.training.utils.hydra_utils import load_experiment_cfg, add_params_to_cfg from super_gradients.training.utils.optimizer_utils import build_optimizer +from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params from super_gradients.training.utils.utils import fuzzy_idx_in_list from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging -from super_gradients.training.metrics import Accuracy, Top5 -from super_gradients.training.utils import random_seed -from super_gradients.training.utils.checkpoint_utils import ( - get_ckpt_local_path, - read_ckpt_state_dict, - load_checkpoint_to_model, - load_pretrained_weights, - get_checkpoints_dir_path, -) -from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger -from super_gradients.training.utils.callbacks import ( - CallbackHandler, - Phase, - LR_SCHEDULERS_CLS_DICT, - PhaseContext, - MetricsUpdateCallback, - LR_WARMUP_CLS_DICT, - ContextSgMethods, - LRCallbackBase, -) -from super_gradients.common.environment.device_utils import device_config -from super_gradients.training.utils import HpmStruct -from super_gradients.training.utils.hydra_utils import load_experiment_cfg, add_params_to_cfg -from omegaconf import OmegaConf -from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory logger = get_logger(__name__) @@ -547,9 +545,11 @@ def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context self.phase_callback_handler.on_train_batch_backward_end(context) # ACCUMULATE GRADIENT FOR X BATCHES BEFORE OPTIMIZING - integrated_batches_num = batch_idx + len(self.train_loader) * epoch + 1 + local_step = batch_idx + 1 + global_step = local_step + len(self.train_loader) * epoch + total_steps = len(self.train_loader) * self.max_epochs - if integrated_batches_num % self.batch_accumulate == 0: + if global_step % self.batch_accumulate == 0: self.phase_callback_handler.on_train_batch_gradient_step_start(context) # APPLY GRADIENT CLIPPING IF REQUIRED @@ -563,7 +563,7 @@ def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context self.optimizer.zero_grad() if self.ema: - self.ema_model.update(self.net, integrated_batches_num / (len(self.train_loader) * self.max_epochs)) + self.ema_model.update(self.net, step=global_step, total_steps=total_steps) # RUN PHASE CALLBACKS self.phase_callback_handler.on_train_batch_gradient_step_end(context) @@ 
-1083,9 +1083,7 @@ def forward(self, inputs, targets): num_batches = len(self.train_loader) if self.ema: - ema_params = self.training_params.ema_params - logger.info(f"Using EMA with params {ema_params}") - self.ema_model = self._instantiate_ema_model(**ema_params) + self.ema_model = self._instantiate_ema_model(self.training_params.ema_params) self.ema_model.updates = self.start_epoch * num_batches // self.batch_accumulate if self.load_checkpoint: if "ema_net" in self.checkpoint.keys(): @@ -1903,14 +1901,15 @@ def _instantiate_net( return net - def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> ModelEMA: + def _instantiate_ema_model(self, ema_params: Mapping[str, Any]) -> ModelEMA: """Instantiate ema model for standard SgModule. - :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value - until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay) - :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to - its final value. beta=15 is ~40% of the training process. + :param decay_type: (str) The decay climb schedule. See EMA_DECAY_FUNCTIONS for more details. + :param decay: The maximum decay value. As the training process advances, the decay will climb towards this value + according to decay_type schedule. See EMA_DECAY_FUNCTIONS for more details. + :param kwargs: Additional parameters for the decay function. See EMA_DECAY_FUNCTIONS for more details. """ - return ModelEMA(self.net, decay, beta, exp_activation) + logger.info(f"Using EMA with params {ema_params}") + return ModelEMA.from_params(self.net, **ema_params) @property def get_net(self): diff --git a/src/super_gradients/training/utils/ema.py b/src/super_gradients/training/utils/ema.py index 0b6a4b9ef1..f57dd73494 100755 --- a/src/super_gradients/training/utils/ema.py +++ b/src/super_gradients/training/utils/ema.py @@ -1,4 +1,3 @@ -import math import warnings from copy import deepcopy from typing import Union @@ -6,9 +5,14 @@ import torch from torch import nn +from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException from super_gradients.training import utils as core_utils from super_gradients.training.models import SgModule from super_gradients.training.models.kd_modules.kd_module import KDModule +from super_gradients.training.utils.ema_decay_schedules import IDecayFunction, EMA_DECAY_FUNCTIONS + +logger = get_logger(__name__) def copy_attr(a: nn.Module, b: nn.Module, include: Union[list, tuple] = (), exclude: Union[list, tuple] = ()): @@ -30,7 +34,7 @@ class ModelEMA: GPU assignment and distributed training wrappers. 
""" - def __init__(self, model, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True): + def __init__(self, model, decay: float, decay_function: IDecayFunction): """ Init the EMA :param model: Union[SgModule, nn.Module], the training model to construct the EMA model by @@ -44,10 +48,8 @@ def __init__(self, model, decay: float = 0.9999, beta: float = 15, exp_activatio # Create EMA self.ema = deepcopy(model) self.ema.eval() - if exp_activation: - self.decay_function = lambda x: decay * (1 - math.exp(-x * beta)) # decay exponential ramp (to help early epochs) - else: - self.decay_function = lambda x: decay # always return the same decay factor + self.decay = decay + self.decay_function = decay_function """" we hold a list of model attributes (not wights and biases) which we would like to include in each @@ -65,15 +67,72 @@ def __init__(self, model, decay: float = 0.9999, beta: float = 15, exp_activatio for p in self.ema.module.parameters(): p.requires_grad_(False) - def update(self, model, training_percent: float): + @classmethod + def from_params(cls, model: nn.Module, decay_type: str = None, decay: float = None, **kwargs): + if decay is None: + logger.warning( + "Parameter `decay` is not specified for EMA params. Please specify `decay` parameter explicitly in your config:\n" + "ema: True\n" + "ema_params: \n" + " decay: 0.9999\n" + " decay_type: exp\n" + " beta: 15\n" + "Will default to decay: 0.9999\n" + "In the next major release of SG this warning will become an error." + ) + decay = 0.9999 + + if "exp_activation" in kwargs: + logger.warning( + "Parameter `exp_activation` is deprecated for EMA model. Please update your config to use decay_type: str (constant|exp|threshold) instead:\n" + "ema: True\n" + "ema_params: \n" + " decay: 0.9999\n" + " decay_type: exp # Equivalent to exp_activation: True\n" + " beta: 15\n" + "\n" + "ema: True\n" + "ema_params: \n" + " decay: 0.9999\n" + " decay_type: constant # Equivalent to exp_activation: False\n" + "\n" + "In the next major release of SG this warning will become an error." + ) + decay_type = "exp" if bool(kwargs.pop("exp_activation")) else "constant" + + if decay_type is None: + logger.warning( + "Parameter decay_type is not specified for EMA model. Please specify decay_type parameter explicitly in your config:\n" + "ema: True\n" + "ema_params: \n" + " decay: 0.9999\n" + " decay_type: constant|exp|threshold\n" + "Will default to `exp` decay with beta = 15\n" + "In the next major release of SG this warning will become an error." + ) + decay_type = "exp" + if "beta" not in kwargs: + kwargs["beta"] = 15 + + try: + decay_cls = EMA_DECAY_FUNCTIONS[decay_type] + except KeyError: + raise UnknownTypeException(decay_type, list(EMA_DECAY_FUNCTIONS.keys())) + + decay_function = decay_cls(**kwargs) + return cls(model, decay, decay_function) + + def update(self, model, step: int, total_steps: int): """ Update the state of the EMA model. - :param model: current training model - :param training_percent: the percentage of the training process [0,1]. 
i.e 0.4 means 40% of the training have passed + + :param model: Current training model + :param step: Current training step + :param total_steps: Total training steps """ # Update EMA parameters with torch.no_grad(): - decay = self.decay_function(training_percent) + decay = self.decay_function(self.decay, step, total_steps) for ema_v, model_v in zip(self.ema.module.state_dict().values(), model.state_dict().values()): if ema_v.dtype.is_floating_point: @@ -101,7 +160,7 @@ class KDModelEMA(ModelEMA): GPU assignment and distributed training wrappers. """ - def __init__(self, kd_model: KDModule, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True): + def __init__(self, kd_model: KDModule, decay: float, decay_function: IDecayFunction): """ Init the EMA :param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by @@ -113,7 +172,7 @@ def __init__(self, kd_model: KDModule, decay: float = 0.9999, beta: float = 15, its final value. beta=15 is ~40% of the training process. """ # Only work on the student (we don't want to update and to have a duplicate of the teacher) - super().__init__(model=core_utils.WrappedModel(kd_model.module.student), decay=decay, beta=beta, exp_activation=exp_activation) + super().__init__(model=core_utils.WrappedModel(kd_model.module.student), decay=decay, decay_function=decay_function) # Overwrite current ema attribute with combination of the student model EMA (current self.ema) # with already the instantiated teacher, to have the final KD EMA diff --git a/src/super_gradients/training/utils/ema_decay_schedules.py b/src/super_gradients/training/utils/ema_decay_schedules.py new file mode 100644 index 0000000000..06b9308322 --- /dev/null +++ b/src/super_gradients/training/utils/ema_decay_schedules.py @@ -0,0 +1,64 @@ +import math +from abc import abstractmethod + +__all__ = ["IDecayFunction", "ConstantDecay", "ThresholdDecay", "ExpDecay", "EMA_DECAY_FUNCTIONS"] + + +class IDecayFunction: + """ + Interface for EMA decay schedules. A decay schedule is a function of the maximum decay value and the training progress. + Usually it gradually increases the EMA decay from a low initial value towards the maximum value. The exact ramp-up schedule is defined by the concrete + implementation. + """ + + @abstractmethod + def __call__(self, decay: float, step: int, total_steps: int) -> float: + """ + + :param decay: The maximum decay value. + :param step: Current training step. The unit-range training percentage can be obtained by `step / total_steps`. + :param total_steps: Total number of training steps. + :return: Computed decay value for a given step. + """ + pass + + +class ConstantDecay(IDecayFunction): + """ + Constant decay schedule.
+ """ + + def __init__(self, **kwargs): + pass + + def __call__(self, decay: float, step: int, total_steps: int) -> float: + return decay + + +class ThresholdDecay(IDecayFunction): + """ + Gradually increase EMA decay from 0.1 to the maximum value using following formula: min(decay, (1 + step) / (10 + step)) + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, decay: float, step, total_steps: int) -> float: + return min(decay, (1 + step) / (10 + step)) + + +class ExpDecay(IDecayFunction): + """ + Gradually increase EMA decay from 0.1 to the maximum value using following formula: decay * (1 - math.exp(-x * self.beta)) + + """ + + def __init__(self, beta: float, **kwargs): + self.beta = beta + + def __call__(self, decay: float, step, total_steps: int) -> float: + x = step / total_steps + return decay * (1 - math.exp(-x * self.beta)) + + +EMA_DECAY_FUNCTIONS = {"constant": ConstantDecay, "threshold": ThresholdDecay, "exp": ExpDecay} diff --git a/tests/integration_tests/ema_train_integration_test.py b/tests/integration_tests/ema_train_integration_test.py index 2c5ed67340..777e5b319c 100644 --- a/tests/integration_tests/ema_train_integration_test.py +++ b/tests/integration_tests/ema_train_integration_test.py @@ -29,11 +29,21 @@ def _init_model(self) -> None: def tearDownClass(cls) -> None: pass - def test_train(self): + def test_train_exp_decay(self): self._init_model() - self._train({}) + self._train({"decay_type": "exp", "beta": 15, "decay": 0.9999}) + + def test_train_threshold_decay(self): + self._init_model() + self._train({"decay_type": "threshold", "decay": 0.9999}) + + def test_train_constant_decay(self): + self._init_model() + self._train({"decay_type": "constant", "decay": 0.9999}) + + def test_train_with_old_ema_params(self): self._init_model() - self._train({"exp_activation": False}) + self._train({"decay": 0.9999, "exp_activation": True, "beta": 10}) def _train(self, ema_params): training_params = { diff --git a/tests/unit_tests/kd_ema_test.py b/tests/unit_tests/kd_ema_test.py index 09980cd4bb..7dc4cf34ab 100644 --- a/tests/unit_tests/kd_ema_test.py +++ b/tests/unit_tests/kd_ema_test.py @@ -34,6 +34,7 @@ def setUp(cls): "greater_metric_to_watch_is_better": True, "average_best_models": False, "ema": True, + "ema_params": {"decay_type": "constant", "decay": 0.999}, } def test_teacher_ema_not_duplicated(self): diff --git a/tests/unit_tests/kd_trainer_test.py b/tests/unit_tests/kd_trainer_test.py index 06fdf91000..8bb2305f99 100644 --- a/tests/unit_tests/kd_trainer_test.py +++ b/tests/unit_tests/kd_trainer_test.py @@ -133,6 +133,8 @@ def test_load_ckpt_best_for_student_with_ema(self): train_params = self.kd_train_params.copy() train_params["max_epochs"] = 1 train_params["ema"] = True + train_params["ema_params"] = {"decay_type": "constant", "decay": 0.999} + kd_trainer.train( training_params=train_params, student=student,