Feature/sg 573 Integrate new EMA decay schedules (#647)
* Refactored EMA instantiation

---------

Co-authored-by: Ofri Masad <ofrimasad@users.noreply.github.com>

BloodAxe and ofrimasad authored Jan 30, 2023
1 parent 9a59706 commit 69a82bc

Showing 11 changed files with 225 additions and 92 deletions.

@@ -16,7 +16,7 @@ ema: True
ema_params:
decay: 0.9999
beta: 15
exp_activation: True
decay_type: exp

train_metrics_list:
- PixelAccuracy:

@@ -30,8 +30,8 @@ criterion_params: {} # when `loss` is one of SuperGradient's built in options, i
ema: False # whether to use Model Exponential Moving Average
ema_params: # parameters for the ema model.
decay: 0.9999
decay_type: exp
beta: 15
exp_activation: True


train_metrics_list: [] # Metrics to log during training. For more information on torchmetrics see https://torchmetrics.rtfd.io/en/latest/.

@@ -17,8 +17,8 @@ optimizer_params:

ema: True
ema_params:
exp_activation: False
decay: 0.9999
decay_type: constant

loss: cross_entropy
criterion_params:

@@ -42,4 +42,3 @@ valid_metrics_list: # metrics for evaluation
- Top5

_convert_: all

@@ -17,7 +17,7 @@ optimizer_params:

ema: True
ema_params:
exp_activation: False
decay_type: constant
decay: 0.9999

loss: cross_entropy
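
Across the recipe changes above, the boolean exp_activation flag is replaced by an explicit decay_type key: "exp" keeps the beta-controlled ramp-up of the decay, while "constant" applies the maximum decay from the first step. A minimal sketch of the two schedules, assuming a progress-based exponential ramp as described in the removed docstrings (the function names and the exact formula below are illustrative, not the library source):

import math

def constant_decay(decay: float, step: int, total_steps: int) -> float:
    # decay_type: constant: the same decay value is used at every EMA update.
    return decay

def exp_decay(decay: float, step: int, total_steps: int, beta: float = 15.0) -> float:
    # decay_type: exp: the effective decay ramps from ~0 toward its maximum as
    # training progresses; a larger beta saturates the ramp earlier (the removed
    # docstring notes that beta=15 saturates at roughly 40% of training).
    progress = step / max(total_steps, 1)
    return decay * (1.0 - math.exp(-progress * beta))
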
39 changes: 19 additions & 20 deletions src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -1,21 +1,14 @@
from typing import Union, Dict, Mapping, Any

import hydra
import torch.nn
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader

from super_gradients.training.utils.distributed_training_utils import setup_device
from super_gradients.common import MultiGPUMode
from super_gradients.training.dataloaders import dataloaders
from super_gradients.training.models import SgModule
from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
from super_gradients.training.models.kd_modules.kd_module import KDModule
from super_gradients.training.sg_trainer import Trainer
from typing import Union, Dict
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training import utils as core_utils, models
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
from super_gradients.training.utils import get_param, HpmStruct
from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model
from super_gradients.training.dataloaders import dataloaders
from super_gradients.training.exceptions.kd_trainer_exceptions import (
ArchitectureKwargsException,
UnsupportedKDArchitectureException,
@@ -24,7 +17,15 @@
TeacherKnowledgeException,
UndefinedNumClassesException,
)
from super_gradients.training.models import SgModule
from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
from super_gradients.training.models.kd_modules.kd_module import KDModule
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
from super_gradients.training.sg_trainer import Trainer
from super_gradients.training.utils import get_param, HpmStruct
from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model
from super_gradients.training.utils.distributed_training_utils import setup_device
from super_gradients.training.utils.ema import KDModelEMA

logger = get_logger(__name__)

@@ -255,17 +256,15 @@ def _get_hyper_param_config(self):
)
return hyper_param_config

def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA:
"""Instantiate KD ema model for KDModule.
If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
:param decay: the maximum decay value. as the training process advances, the decay will climb towards
this value until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
:param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will
saturate to its final value. beta=15 is ~40% of the training process.
:param exp_activation:
def _instantiate_ema_model(self, ema_params: Mapping[str, Any]) -> KDModelEMA:
"""Instantiate ema model for standard SgModule.
:param decay_type: (str) The decay climb schedule. See EMA_DECAY_FUNCTIONS for more details.
:param decay: The maximum decay value. As the training process advances, the decay will climb towards this value
according to decay_type schedule. See EMA_DECAY_FUNCTIONS for more details.
:param kwargs: Additional parameters for the decay function. See EMA_DECAY_FUNCTIONS for more details.
"""
return KDModelEMA(self.net, decay, beta, exp_activation)
logger.info(f"Using EMA with params {ema_params}")
return KDModelEMA.from_params(self.net, **ema_params)

def _save_best_checkpoint(self, epoch, state):
"""
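
The KD trainer no longer hard-codes the EMA constructor signature (decay, beta, exp_activation); it forwards the recipe's whole ema_params mapping to KDModelEMA.from_params, so a new decay schedule only needs a new decay_type entry in EMA_DECAY_FUNCTIONS rather than a trainer change. A rough sketch of such a factory, assuming EMA_DECAY_FUNCTIONS maps decay_type strings to schedule builders (the class name EMAFactorySketch and the signatures below are illustrative; the real code lives in super_gradients.training.utils.ema and may differ):

import math
from typing import Callable, Dict

# Assumed registry shape: decay_type -> builder that turns schedule kwargs (e.g. beta)
# into a function (decay, step, total_steps) -> effective decay.
EMA_DECAY_FUNCTIONS: Dict[str, Callable[..., Callable[[float, int, int], float]]] = {
    "constant": lambda **kwargs: (lambda decay, step, total_steps: decay),
    "exp": lambda beta=15.0, **kwargs: (
        lambda decay, step, total_steps: decay * (1.0 - math.exp(-(step / max(total_steps, 1)) * beta))
    ),
}

class EMAFactorySketch:
    def __init__(self, net, decay: float, decay_function: Callable[[float, int, int], float]):
        self.net = net
        self.decay = decay
        self.decay_function = decay_function

    @classmethod
    def from_params(cls, net, decay: float = 0.9999, decay_type: str = "exp", **kwargs):
        # Any extra keys in ema_params (e.g. beta for "exp") configure the schedule.
        decay_function = EMA_DECAY_FUNCTIONS[decay_type](**kwargs)
        return cls(net, decay=decay, decay_function=decay_function)

# Usage mirroring the refactored call site above:
# ema_model = EMAFactorySketch.from_params(net, **{"decay": 0.9999, "decay_type": "exp", "beta": 15})
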
103 changes: 51 additions & 52 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -1,49 +1,70 @@
import inspect
import os
from copy import deepcopy
from typing import Union, Tuple, Mapping, Dict
from pathlib import Path
from typing import Union, Tuple, Mapping, Dict, Any

import hydra
import numpy as np
import torch
import hydra
from omegaconf import DictConfig
from omegaconf import OmegaConf
from piptools.scripts.sync import _get_installed_distributions
from torch import nn
from torch.utils.data import DataLoader, SequentialSampler
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torchmetrics import MetricCollection
from tqdm import tqdm
from piptools.scripts.sync import _get_installed_distributions

from torch.utils.data.distributed import DistributedSampler

from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler

from super_gradients.common.factories.callbacks_factory import CallbacksFactory
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.data_types.enum import MultiGPUMode, StrictLoad, EvaluationType
from super_gradients.training.models.all_architectures import ARCHITECTURES
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.device_utils import device_config
from super_gradients.common.factories.callbacks_factory import CallbacksFactory
from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.common.factories.losses_factory import LossesFactory
from super_gradients.common.factories.metrics_factory import MetricsFactory
from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory
from super_gradients.common.sg_loggers import SG_LOGGERS
from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
from super_gradients.training import utils as core_utils, models, dataloaders
from super_gradients.training.models import SgModule
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
from super_gradients.training.utils import sg_trainer_utils, get_param
from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger
from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler
from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.metrics.metric_utils import (
get_metrics_titles,
get_metrics_results_tuple,
get_logging_values,
get_metrics_dict,
get_train_loop_description_dict,
)
from super_gradients.training.models import SgModule
from super_gradients.training.models.all_architectures import ARCHITECTURES
from super_gradients.training.params import TrainingParams
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
from super_gradients.training.utils import HpmStruct
from super_gradients.training.utils import random_seed
from super_gradients.training.utils import sg_trainer_utils, get_param
from super_gradients.training.utils.callbacks import (
CallbackHandler,
Phase,
LR_SCHEDULERS_CLS_DICT,
PhaseContext,
MetricsUpdateCallback,
LR_WARMUP_CLS_DICT,
ContextSgMethods,
LRCallbackBase,
)
from super_gradients.training.utils.checkpoint_utils import (
get_ckpt_local_path,
read_ckpt_state_dict,
load_checkpoint_to_model,
load_pretrained_weights,
get_checkpoints_dir_path,
)
from super_gradients.training.utils.distributed_training_utils import (
MultiGPUModeAutocastWrapper,
reduce_results_tuple_for_ddp,
@@ -59,34 +80,11 @@
DDPNotSetupException,
)
from super_gradients.training.utils.ema import ModelEMA
from super_gradients.training.utils.hydra_utils import load_experiment_cfg, add_params_to_cfg
from super_gradients.training.utils.optimizer_utils import build_optimizer
from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
from super_gradients.training.utils.utils import fuzzy_idx_in_list
from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.utils import random_seed
from super_gradients.training.utils.checkpoint_utils import (
get_ckpt_local_path,
read_ckpt_state_dict,
load_checkpoint_to_model,
load_pretrained_weights,
get_checkpoints_dir_path,
)
from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger
from super_gradients.training.utils.callbacks import (
CallbackHandler,
Phase,
LR_SCHEDULERS_CLS_DICT,
PhaseContext,
MetricsUpdateCallback,
LR_WARMUP_CLS_DICT,
ContextSgMethods,
LRCallbackBase,
)
from super_gradients.common.environment.device_utils import device_config
from super_gradients.training.utils import HpmStruct
from super_gradients.training.utils.hydra_utils import load_experiment_cfg, add_params_to_cfg
from omegaconf import OmegaConf
from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory

logger = get_logger(__name__)


@@ -547,9 +545,11 @@ def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context
self.phase_callback_handler.on_train_batch_backward_end(context)

# ACCUMULATE GRADIENT FOR X BATCHES BEFORE OPTIMIZING
integrated_batches_num = batch_idx + len(self.train_loader) * epoch + 1
local_step = batch_idx + 1
global_step = local_step + len(self.train_loader) * epoch
total_steps = len(self.train_loader) * self.max_epochs

if integrated_batches_num % self.batch_accumulate == 0:
if global_step % self.batch_accumulate == 0:
self.phase_callback_handler.on_train_batch_gradient_step_start(context)

# APPLY GRADIENT CLIPPING IF REQUIRED
@@ -563,7 +563,7 @@ def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context

self.optimizer.zero_grad()
if self.ema:
self.ema_model.update(self.net, integrated_batches_num / (len(self.train_loader) * self.max_epochs))
self.ema_model.update(self.net, step=global_step, total_steps=total_steps)

# RUN PHASE CALLBACKS
self.phase_callback_handler.on_train_batch_gradient_step_end(context)
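
The gradient-accumulation bookkeeping above now keeps the raw counters (global_step, total_steps) and hands them to the EMA object, instead of pre-dividing them into a single progress fraction; the decay schedule is applied inside the EMA update itself. A minimal sketch of what such an update can look like, assuming the parameter blending EMA_t+1 = EMA_t * decay + online * (1 - decay) described in the removed docstrings (the class name EMAUpdateSketch and the state_dict-based loop are illustrative, not the library source):

import copy
import torch

class EMAUpdateSketch:
    def __init__(self, net: torch.nn.Module, decay: float, decay_function):
        self.ema = copy.deepcopy(net).eval()  # frozen shadow copy of the online model
        self.decay = decay
        self.decay_function = decay_function  # (decay, step, total_steps) -> effective decay

    @torch.no_grad()
    def update(self, net: torch.nn.Module, step: int, total_steps: int) -> None:
        d = self.decay_function(self.decay, step, total_steps)
        online_state = net.state_dict()
        # Blend each floating-point tensor of the EMA copy toward the online model.
        for name, ema_tensor in self.ema.state_dict().items():
            if ema_tensor.dtype.is_floating_point:
                ema_tensor.mul_(d).add_(online_state[name].detach(), alpha=1.0 - d)
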
@@ -1083,9 +1083,7 @@ def forward(self, inputs, targets):
num_batches = len(self.train_loader)

if self.ema:
ema_params = self.training_params.ema_params
logger.info(f"Using EMA with params {ema_params}")
self.ema_model = self._instantiate_ema_model(**ema_params)
self.ema_model = self._instantiate_ema_model(self.training_params.ema_params)
self.ema_model.updates = self.start_epoch * num_batches // self.batch_accumulate
if self.load_checkpoint:
if "ema_net" in self.checkpoint.keys():

@@ -1903,14 +1901,15 @@ def _instantiate_net(

return net

def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> ModelEMA:
def _instantiate_ema_model(self, ema_params: Mapping[str, Any]) -> ModelEMA:
"""Instantiate ema model for standard SgModule.
:param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
:param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to
its final value. beta=15 is ~40% of the training process.
:param decay_type: (str) The decay climb schedule. See EMA_DECAY_FUNCTIONS for more details.
:param decay: The maximum decay value. As the training process advances, the decay will climb towards this value
according to decay_type schedule. See EMA_DECAY_FUNCTIONS for more details.
:param kwargs: Additional parameters for the decay function. See EMA_DECAY_FUNCTIONS for more details.
"""
return ModelEMA(self.net, decay, beta, exp_activation)
logger.info(f"Using EMA with params {ema_params}")
return ModelEMA.from_params(self.net, **ema_params)

@property
def get_net(self):
