diff --git a/Makefile b/Makefile
index 6b950b02f0..6720befe96 100644
--- a/Makefile
+++ b/Makefile
@@ -8,9 +8,9 @@ yolo_nas_integration_tests:
 	python -m unittest tests/integration_tests/yolo_nas_integration_test.py

 recipe_accuracy_tests:
-	python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test epochs=1 batch_size=4 val_batch_size=8 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_yolox experiment_name=shortened_coco2017_yolox_n_map_test epochs=10 architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=cityscapes_regseg48 experiment_name=shortened_cityscapes_regseg48_iou_test epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test epochs=1 batch_size=4 val_batch_size=8 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=coco2017_yolox experiment_name=shortened_coco2017_yolox_n_map_test epochs=10 architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=cityscapes_regseg48 experiment_name=shortened_cityscapes_regseg48_iou_test epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
 	coverage run --source=super_gradients -m unittest tests/deci_core_recipe_test_suite_runner.py
diff --git a/src/super_gradients/training/kd_trainer/kd_trainer.py b/src/super_gradients/training/kd_trainer/kd_trainer.py
index 97a4c2b1db..3cce2373d2 100644
--- a/src/super_gradients/training/kd_trainer/kd_trainer.py
+++ b/src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -27,6 +27,7 @@
 from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model
 from super_gradients.training.utils.distributed_training_utils import setup_device
 from super_gradients.training.utils.ema import KDModelEMA
+from super_gradients.training.utils.utils import unwrap_model

 logger = get_logger(__name__)

@@ -211,7 +212,7 @@ def _load_checkpoint_to_model(self):
         the entire KD network following the same logic as in Trainer.
         """
         teacher_checkpoint_path = get_param(self.checkpoint_params, "teacher_checkpoint_path")
-        teacher_net = self.net.module.teacher
+        teacher_net = unwrap_model(self.net).teacher

         if teacher_checkpoint_path is not None:
@@ -271,12 +272,12 @@ def _save_best_checkpoint(self, epoch, state):
         Overrides parent best_ckpt saving to modify the state dict so that we only save the student.
         """
         if self.ema:
-            best_net = core_utils.WrappedModel(self.ema_model.ema.module.student)
+            best_net = self.ema_model.ema.student
             state.pop("ema_net")
         else:
-            best_net = core_utils.WrappedModel(self.net.module.student)
+            best_net = self.net.student

-        state["net"] = best_net.state_dict()
+        state["net"] = unwrap_model(best_net).state_dict()
         self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)

     def train(
diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 1a3b38e51b..0ce5ad0d94 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -66,7 +66,7 @@
 from super_gradients.training.utils.ema import ModelEMA
 from super_gradients.training.utils.optimizer_utils import build_optimizer
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
-from super_gradients.training.utils.utils import fuzzy_idx_in_list
+from super_gradients.training.utils.utils import fuzzy_idx_in_list, unwrap_model
 from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.utils import random_seed
@@ -396,9 +396,6 @@ def _net_to_device(self):
             local_rank = int(device_config.device.split(":")[1])
             self.net = torch.nn.parallel.DistributedDataParallel(self.net, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

-        else:
-            self.net = core_utils.WrappedModel(self.net)
-
     def _train_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
         """
         train_epoch - A single epoch training procedure
@@ -601,7 +598,8 @@ def _save_checkpoint(
             metric = validation_results_dict[self.metric_to_watch]

         # BUILD THE state_dict
-        state = {"net": self.net.state_dict(), "acc": metric, "epoch": epoch}
+        state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch}
+
         if optimizer is not None:
             state["optimizer_state_dict"] = optimizer.state_dict()
@@ -609,7 +607,7 @@
             state["scaler_state_dict"] = self.scaler.state_dict()

         if self.ema:
-            state["ema_net"] = self.ema_model.ema.state_dict()
+            state["ema_net"] = unwrap_model(self.ema_model.ema).state_dict()

         processing_params = self._get_preprocessing_from_valid_loader()
         if processing_params is not None:
@@ -636,7 +634,7 @@
             logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))

         if self.training_params.average_best_models:
-            net_for_averaging = self.ema_model.ema if self.ema else self.net
+            net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
             state["net"] = self.model_weight_averaging.get_average_model(net_for_averaging, validation_results_dict=validation_results_dict)
             self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename, state_dict=state, global_step=epoch)
@@ -652,7 +650,7 @@ def _prep_net_for_train(self) -> None:
         self._net_to_device()

         # SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE
-        self.update_param_groups = hasattr(self.net.module, "update_param_groups")
+        self.update_param_groups = hasattr(unwrap_model(self.net), "update_param_groups")

         self.checkpoint = {}
         self.strict_load = core_utils.get_param(self.training_params, "resume_strict_load", StrictLoad.ON)
@@ -1161,7 +1159,9 @@ def forward(self, inputs, targets):
         if not self.ddp_silent_mode:
             if self.training_params.dataset_statistics:
                 dataset_statistics_logger = DatasetStatisticsTensorboardLogger(self.sg_logger)
-                dataset_statistics_logger.analyze(self.train_loader, all_classes=self.classes, title="Train-set", anchors=self.net.module.arch_params.anchors)
+                dataset_statistics_logger.analyze(
+                    self.train_loader, all_classes=self.classes, title="Train-set", anchors=unwrap_model(self.net).arch_params.anchors
+                )
                 dataset_statistics_logger.analyze(self.valid_loader, all_classes=self.classes, title="val-set")

         sg_trainer_utils.log_uncaught_exceptions(logger)
@@ -1175,7 +1175,7 @@ def forward(self, inputs, targets):
         if isinstance(self.training_params.optimizer, str) or (
             inspect.isclass(self.training_params.optimizer) and issubclass(self.training_params.optimizer, torch.optim.Optimizer)
         ):
-            self.optimizer = build_optimizer(net=self.net, lr=self.training_params.initial_lr, training_params=self.training_params)
+            self.optimizer = build_optimizer(net=unwrap_model(self.net), lr=self.training_params.initial_lr, training_params=self.training_params)
         elif isinstance(self.training_params.optimizer, torch.optim.Optimizer):
             self.optimizer = self.training_params.optimizer
         else:
@@ -1248,7 +1248,7 @@ def forward(self, inputs, targets):
         processing_params = self._get_preprocessing_from_valid_loader()
         if processing_params is not None:
-            self.net.module.set_dataset_processing_params(**processing_params)
+            unwrap_model(self.net).set_dataset_processing_params(**processing_params)

         try:
             # HEADERS OF THE TRAINING PROGRESS
@@ -1295,7 +1295,7 @@ def forward(self, inputs, targets):
                     num_gpus=get_world_size(),
                 )

-                # model switch - we replace self.net.module with the ema model for the testing and saving part
+                # model switch - we replace self.net with the ema model for the testing and saving part
                 # and then switch it back before the next training epoch
                 if self.ema:
                     self.ema_model.update_attr(self.net)
@@ -1355,7 +1355,7 @@ def forward(self, inputs, targets):
     def _get_preprocessing_from_valid_loader(self) -> Optional[dict]:
         valid_loader = self.valid_loader

-        if isinstance(self.net.module, HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
+        if isinstance(unwrap_model(self.net), HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
             try:
                 return valid_loader.dataset.get_dataset_preprocessing_params()
             except Exception as e:
@@ -1413,7 +1413,7 @@ def _initialize_mixed_precision(self, mixed_precision_enabled: bool):
                 def hook(module, _):
                     module.forward = MultiGPUModeAutocastWrapper(module.forward)

-                self.net.module.register_forward_pre_hook(hook=hook)
+                unwrap_model(self.net).register_forward_pre_hook(hook=hook)

             if self.load_checkpoint:
                 scaler_state_dict = core_utils.get_param(self.checkpoint, "scaler_state_dict")
@@ -1439,7 +1439,7 @@ def _validate_final_average_model(self, cleanup_snapshots_pkl_file=False):
         with wait_for_the_master(local_rank):
             average_model_sd = read_ckpt_state_dict(average_model_ckpt_path)["net"]

-        self.net.load_state_dict(average_model_sd)
+        unwrap_model(self.net).load_state_dict(average_model_sd)
         # testing the averaged model and save instead of best model if needed
         averaged_model_results_dict = self._validate_epoch(epoch=self.max_epochs)
@@ -1462,7 +1462,7 @@ def get_arch_params(self):

     @property
     def get_structure(self):
-        return self.net.module.structure
+        return unwrap_model(self.net).structure

     @property
     def get_architecture(self):
@@ -1494,7 +1494,8 @@ def _re_build_model(self, arch_params={}):
         if device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
             logger.warning("Warning: distributed training is not supported in re_build_model()")
-        self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids()) if device_config.multi_gpu else core_utils.WrappedModel(self.net)
+        if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
+            self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids())

     @property
     def get_module(self):
@@ -1635,7 +1636,7 @@ def _initialize_sg_logger_objects(self, additional_configs_to_log: Dict = None):
             if "model_name" in get_callable_param_names(sg_logger_cls.__init__):
                 if sg_logger_params.get("model_name") is None:
                     # Use the model name used in `models.get(...)` if relevant
-                    sg_logger_params["model_name"] = get_model_name(self.net.module)
+                    sg_logger_params["model_name"] = get_model_name(unwrap_model(self.net))

                 if sg_logger_params["model_name"] is None:
                     raise ValueError(
diff --git a/src/super_gradients/training/utils/__init__.py b/src/super_gradients/training/utils/__init__.py
index f75d9c7438..4cf763988d 100755
--- a/src/super_gradients/training/utils/__init__.py
+++ b/src/super_gradients/training/utils/__init__.py
@@ -1,7 +1,6 @@
 from super_gradients.training.utils.utils import (
     Timer,
     HpmStruct,
-    WrappedModel,
     convert_to_tensor,
     get_param,
     tensor_container_to_device,
@@ -17,7 +16,6 @@
 __all__ = [
     "Timer",
     "HpmStruct",
-    "WrappedModel",
     "convert_to_tensor",
     "get_param",
     "tensor_container_to_device",
diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py
index d3d714f87c..1f2991d4e7 100644
--- a/src/super_gradients/training/utils/callbacks/callbacks.py
+++ b/src/super_gradients/training/utils/callbacks/callbacks.py
@@ -23,7 +23,7 @@
 from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
 from super_gradients.common.environment.ddp_utils import multi_process_safe
 from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path
-
+from super_gradients.training.utils.utils import unwrap_model

 logger = get_logger(__name__)

@@ -75,7 +75,7 @@ def __init__(self, model_name: str, input_dimensions: Sequence[int], primary_bat
         self.atol = kwargs.get("atol", 1e-05)

     def __call__(self, context: PhaseContext):
-        model = copy.deepcopy(context.net.module)
+        model = copy.deepcopy(unwrap_model(context.net))
         model = model.cpu()
         model.eval()  # Put model into eval mode

@@ -204,12 +204,12 @@ def __call__(self, context: PhaseContext) -> None:
         :param context: Training phase context
         """
         try:
-            model = copy.deepcopy(context.net)
+            model = copy.deepcopy(unwrap_model(context.net))
             model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
             model_state_dict = torch.load(model_state_dict_path)["net"]
             model.load_state_dict(state_dict=model_state_dict)

-            model = model.module.cpu()
+            model = model.cpu()
             if hasattr(model, "prep_model_for_conversion"):
                 model.prep_model_for_conversion(input_size=self.input_dimensions)

@@ -267,7 +269,9 @@ def perform_scheduling(self, context: PhaseContext):

     def update_lr(self, optimizer, epoch, batch_idx=None):
         if self.update_param_groups:
-            param_groups = self.net.module.update_param_groups(optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len)
+            param_groups = unwrap_model(self.net).update_param_groups(
+                optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
+            )
             optimizer.param_groups = param_groups
         else:
             # UPDATE THE OPTIMIZERS PARAMETER
@@ -373,7 +375,9 @@ def update_lr(self, optimizer, epoch, batch_idx=None):
         :return:
         """
         if self.update_param_groups:
-            param_groups = self.net.module.update_param_groups(optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len)
+            param_groups = unwrap_model(self.net).update_param_groups(
+                optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
+            )
             optimizer.param_groups = param_groups
         else:
             # UPDATE THE OPTIMIZERS PARAMETER
diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index 8dc8ebe35c..904a65d5de 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -14,6 +14,7 @@
 from super_gradients.module_interfaces import HasPredict
 from super_gradients.training.pretrained_models import MODEL_URLS
 from super_gradients.training.utils.distributed_training_utils import get_local_rank, wait_for_the_master
+from super_gradients.training.utils.utils import unwrap_model

 try:
     from torch.hub import download_url_to_file, load_state_dict_from_url
@@ -54,6 +55,13 @@ def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: Uni
     :return:
     """
     state_dict = state_dict["net"] if "net" in state_dict else state_dict
+
+    # This is a backward compatibility fix for checkpoints that were saved with DataParallel/DistributedDataParallel wrapper
+    # and contains "module." prefix in all keys
+    # If all keys start with "module.", then we remove it.
+    if all([key.startswith("module.") for key in state_dict.keys()]):
+        state_dict = collections.OrderedDict([(key[7:], value) for key, value in state_dict.items()])
+
     try:
         strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
         net.load_state_dict(state_dict, strict=strict_bool)
@@ -217,6 +225,8 @@ def load_checkpoint_to_model(
     if isinstance(strict, str):
         strict = StrictLoad(strict)

+    net = unwrap_model(net)
+
     if load_backbone and not hasattr(net, "backbone"):
         raise ValueError("No backbone attribute in net - Can't load backbone weights")

@@ -239,7 +249,7 @@ def load_checkpoint_to_model(
     message_model = "model" if not load_backbone else "model's backbone"
     logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)

-    if (isinstance(net, HasPredict) or (hasattr(net, "module") and isinstance(net.module, HasPredict))) and load_processing_params:
+    if (isinstance(net, HasPredict)) and load_processing_params:
         if "processing_params" not in checkpoint.keys():
             raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
         try:
@@ -275,7 +285,7 @@ def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):

     if (
         ckpt_val.shape != model_val.shape
-        and ckpt_key == "module._backbone._modules_list.0.conv.conv.weight"
+        and (ckpt_key == "module._backbone._modules_list.0.conv.conv.weight" or ckpt_key == "_backbone._modules_list.0.conv.conv.weight")
         and model_key == "_backbone._modules_list.0.conv.weight"
     ):
         model_val.data[:, :, ::2, ::2] = ckpt_val.data[:, :3]
diff --git a/src/super_gradients/training/utils/ema.py b/src/super_gradients/training/utils/ema.py
index f57dd73494..1710f577cd 100755
--- a/src/super_gradients/training/utils/ema.py
+++ b/src/super_gradients/training/utils/ema.py
@@ -3,11 +3,11 @@
 from typing import Union

 import torch
+from super_gradients.training.utils.utils import unwrap_model
 from torch import nn

 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
-from super_gradients.training import utils as core_utils
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.utils.ema_decay_schedules import IDecayFunction, EMA_DECAY_FUNCTIONS
@@ -34,7 +34,7 @@ class ModelEMA:
     GPU assignment and distributed training wrappers.
     """

-    def __init__(self, model, decay: float, decay_function: IDecayFunction):
+    def __init__(self, model: nn.Module, decay: float, decay_function: IDecayFunction):
         """
         Init the EMA
         :param model: Union[SgModule, nn.Module], the training model to construct the EMA model by
@@ -46,6 +46,7 @@ def __init__(self, model, decay: float, decay_function: IDecayFunction):
             its final value. beta=15 is ~40% of the training process.
         """
         # Create EMA
+        model = unwrap_model(model)
         self.ema = deepcopy(model)
         self.ema.eval()
         self.decay = decay
@@ -57,14 +58,14 @@ def __init__(self, model, decay: float, decay_function: IDecayFunction):
         get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule
         all non-private (not starting with '_') attributes will be updated (and only them).
         """
-        if isinstance(model.module, SgModule):
-            self.include_attributes = model.module.get_include_attributes()
-            self.exclude_attributes = model.module.get_exclude_attributes()
+        if isinstance(model, SgModule):
+            self.include_attributes = model.get_include_attributes()
+            self.exclude_attributes = model.get_exclude_attributes()
         else:
             warnings.warn("Warning: EMA should be used with SgModule instance. All attributes of the model will be " "included in EMA")
             self.include_attributes = []
             self.exclude_attributes = []

-        for p in self.ema.module.parameters():
+        for p in self.ema.parameters():
             p.requires_grad_(False)

     @classmethod
@@ -131,10 +132,11 @@ def update(self, model, step: int, total_steps: int):
         :param total_steps: Total training steps
         """
         # Update EMA parameters
+        model = unwrap_model(model)
         with torch.no_grad():
             decay = self.decay_function(self.decay, step, total_steps)

-            for ema_v, model_v in zip(self.ema.module.state_dict().values(), model.state_dict().values()):
+            for ema_v, model_v in zip(self.ema.state_dict().values(), model.state_dict().values()):
                 if ema_v.dtype.is_floating_point:
                     ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v.detach())

@@ -147,7 +149,7 @@ def update_attr(self, model):
         attributes will be updated (and only them).
         :param model: the source model
         """
-        copy_attr(self.ema.module, model.module, self.include_attributes, self.exclude_attributes)
+        copy_attr(self.ema, unwrap_model(model), self.include_attributes, self.exclude_attributes)


 class KDModelEMA(ModelEMA):
@@ -172,15 +174,13 @@ def __init__(self, kd_model: KDModule, decay: float, decay_function: IDecayFunct
             its final value. beta=15 is ~40% of the training process.
         """
         # Only work on the student (we don't want to update and to have a duplicate of the teacher)
-        super().__init__(model=core_utils.WrappedModel(kd_model.module.student), decay=decay, decay_function=decay_function)
+        super().__init__(model=kd_model.student, decay=decay, decay_function=decay_function)

         # Overwrite current ema attribute with combination of the student model EMA (current self.ema)
         # with already the instantiated teacher, to have the final KD EMA
-        self.ema = core_utils.WrappedModel(
-            KDModule(
-                arch_params=kd_model.module.arch_params,
-                student=self.ema.module,
-                teacher=kd_model.module.teacher,
-                run_teacher_on_eval=kd_model.module.run_teacher_on_eval,
-            )
+        self.ema = KDModule(
+            arch_params=kd_model.arch_params,
+            student=self.ema,
+            teacher=kd_model.teacher,
+            run_teacher_on_eval=kd_model.run_teacher_on_eval,
         )
diff --git a/src/super_gradients/training/utils/optimizer_utils.py b/src/super_gradients/training/utils/optimizer_utils.py
index 7844f67a9f..5f47a7f72e 100755
--- a/src/super_gradients/training/utils/optimizer_utils.py
+++ b/src/super_gradients/training/utils/optimizer_utils.py
@@ -1,7 +1,5 @@
-import torch.optim as optim
 import torch.nn as nn
-from torch.nn.modules.batchnorm import _BatchNorm
-from torch.nn.modules.conv import _ConvNd
+import torch.optim as optim
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.factories.optimizers_type_factory import OptimizersTypeFactory
 from super_gradients.training.params import (
@@ -12,6 +10,9 @@
 )
 from super_gradients.training.utils import get_param
 from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF
+from super_gradients.training.utils.utils import is_model_wrapped
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.conv import _ConvNd

 logger = get_logger(__name__)

@@ -86,6 +87,8 @@ def build_optimizer(net: nn.Module, lr: float, training_params) -> optim.Optimiz
     :param lr: initial learning rate
     :param training_params: training_parameters
     """
+    if is_model_wrapped(net):
+        raise ValueError("Argument net for build_optimizer must be an unwrapped model. " "Please use build_optimizer(unwrap_model(net), ...).")
     if isinstance(training_params.optimizer, str):
         optimizer_cls = OptimizersTypeFactory().get(training_params.optimizer)
     else:
@@ -96,14 +99,14 @@ def build_optimizer(net: nn.Module, lr: float, training_params) -> optim.Optimiz
     weight_decay = get_param(training_params.optimizer_params, "weight_decay", 0.0)

     # OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT
-    if hasattr(net.module, "initialize_param_groups"):
+    if hasattr(net, "initialize_param_groups"):
         # INITIALIZE_PARAM_GROUPS MUST RETURN A LIST OF DICTS WITH 'named_params' AND OPTIMIZER's ATTRIBUTES PER GROUP
-        net_named_params = net.module.initialize_param_groups(lr, training_params)
+        net_named_params = net.initialize_param_groups(lr, training_params)
     else:
         net_named_params = [{"named_params": net.named_parameters()}]

     if training_params.zero_weight_decay_on_bias_and_bn:
-        optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(net.module, net_named_params, weight_decay)
+        optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(net, net_named_params, weight_decay)
     else:
         # Overwrite groups to include params instead of named params
diff --git a/src/super_gradients/training/utils/utils.py b/src/super_gradients/training/utils/utils.py
index 6a6f05a6be..a8caf0d6cc 100755
--- a/src/super_gradients/training/utils/utils.py
+++ b/src/super_gradients/training/utils/utils.py
@@ -6,6 +6,7 @@
 import time
 import inspect
 import typing
+import warnings
 from functools import lru_cache, wraps
 from importlib import import_module
 from itertools import islice
@@ -13,6 +14,7 @@
 from pathlib import Path
 from typing import Mapping, Optional, Tuple, Union, List, Dict, Any, Iterable
 from zipfile import ZipFile
+from torch.nn.parallel import DistributedDataParallel

 import numpy as np
 import torch
@@ -80,12 +82,36 @@ def validate(self):

 class WrappedModel(nn.Module):
     def __init__(self, module):
         super(WrappedModel, self).__init__()
+        warnings.warn(
+            "WrappedModel is deprecated and will be removed in next major release of SuperGradients. "
+            "You don't need to wrap your model anymore, simply remove it and everything will work as expected.",
+            DeprecationWarning,
+        )
+
         self.module = module  # that I actually define.

     def forward(self, x):
         return self.module(x)


+def is_model_wrapped(model: nn.Module) -> bool:
+    return isinstance(model, (nn.DataParallel, DistributedDataParallel, WrappedModel))
+
+
+def unwrap_model(model: Union[nn.Module, nn.DataParallel, DistributedDataParallel]) -> nn.Module:
+    """
+    Get the real model from a model wrapper (DataParallel, DistributedDataParallel)
+
+    :param model:
+    :return:
+    """
+    if is_model_wrapped(model):
+        return model.module
+    elif isinstance(model, nn.Module):
+        return model
+    raise ValueError(f"Unknown model type: {type(model)}")
+
+
 def arch_params_deprecated(func):
     """
     Since initialization of arch_params is deprecated and will be removed, this decorator will be used to wrap the _init_
diff --git a/src/super_gradients/training/utils/weight_averaging_utils.py b/src/super_gradients/training/utils/weight_averaging_utils.py
index ab04fa5bc4..ecbccadb7c 100755
--- a/src/super_gradients/training/utils/weight_averaging_utils.py
+++ b/src/super_gradients/training/utils/weight_averaging_utils.py
@@ -2,7 +2,7 @@
 import torch
 import numpy as np
 from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict
-from super_gradients.training.utils.utils import move_state_dict_to_device
+from super_gradients.training.utils.utils import move_state_dict_to_device, unwrap_model


 class ModelWeightAveraging:
@@ -60,7 +60,7 @@ def update_snapshots_dict(self, model, validation_results_dict):
         require_update, update_ind = self._is_better(averaging_snapshots_dict, validation_results_dict)
         if require_update:
             # moving state dict to cpu
-            new_sd = model.state_dict()
+            new_sd = unwrap_model(model).state_dict()
             new_sd = move_state_dict_to_device(new_sd, "cpu")

             averaging_snapshots_dict["snapshot" + str(update_ind)] = new_sd
diff --git a/tests/unit_tests/kd_ema_test.py b/tests/unit_tests/kd_ema_test.py
index 7dc4cf34ab..1f59084fe7 100644
--- a/tests/unit_tests/kd_ema_test.py
+++ b/tests/unit_tests/kd_ema_test.py
@@ -5,7 +5,7 @@
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.kd_trainer import KDTrainer
 import torch
-from super_gradients.training.utils.utils import check_models_have_same_weights
+from super_gradients.training.utils.utils import check_models_have_same_weights, unwrap_model
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.losses.kd_losses import KDLogitsLoss
 from super_gradients.common.object_names import Models
@@ -52,8 +52,8 @@ def test_teacher_ema_not_duplicated(self):
             valid_loader=classification_test_dataloader(),
         )

-        self.assertTrue(kd_model.ema_model.ema.module.teacher is kd_model.net.module.teacher)
-        self.assertTrue(kd_model.ema_model.ema.module.student is not kd_model.net.module.student)
+        self.assertTrue(unwrap_model(kd_model.ema_model.ema).teacher is unwrap_model(kd_model.net).teacher)
+        self.assertTrue(unwrap_model(kd_model.ema_model.ema).student is not unwrap_model(kd_model.net).student)

     def test_kd_ckpt_reload_net(self):
         """Check that the KD trainer load correctly from checkpoint when "load_ema_as_net=False"."""
@@ -100,10 +100,10 @@ def test_kd_ckpt_reload_net(self):
         self.assertTrue(not check_models_have_same_weights(reloaded_net, ema_model))

         # loaded student ema == loaded student net (since load_ema_as_net = False)
-        self.assertTrue(not check_models_have_same_weights(reloaded_ema_model.module.student, reloaded_net.module.student))
+        self.assertTrue(not check_models_have_same_weights(reloaded_ema_model.student, reloaded_net.student))

         # loaded teacher ema == loaded teacher net (teacher always loads ema)
-        self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.teacher, reloaded_net.module.teacher))
+        self.assertTrue(check_models_have_same_weights(reloaded_ema_model.teacher, reloaded_net.teacher))


 if __name__ == "__main__":
diff --git a/tests/unit_tests/kd_trainer_test.py b/tests/unit_tests/kd_trainer_test.py
index 2b178de217..3b866e7b2a 100644
--- a/tests/unit_tests/kd_trainer_test.py
+++ b/tests/unit_tests/kd_trainer_test.py
@@ -81,10 +81,10 @@ def test_train_kd_module_external_models(self):
         )

         # TEACHER WEIGHT'S SHOULD REMAIN THE SAME
-        self.assertTrue(check_models_have_same_weights(teacher_model, sg_model.net.module.teacher))
+        self.assertTrue(check_models_have_same_weights(teacher_model, sg_model.net.teacher))

         # STUDENT WEIGHT'S SHOULD NOT REMAIN THE SAME
-        self.assertFalse(check_models_have_same_weights(student_model, sg_model.net.module.student))
+        self.assertFalse(check_models_have_same_weights(student_model, sg_model.net.student))

     def test_train_model_with_input_adapter(self):
         kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter")
@@ -105,7 +105,7 @@ def test_train_model_with_input_adapter(self):
             valid_loader=classification_test_dataloader(),
         )

-        self.assertEqual(kd_trainer.net.module.teacher_input_adapter, adapter)
+        self.assertEqual(kd_trainer.net.teacher_input_adapter, adapter)

     def test_load_ckpt_best_for_student(self):
         kd_trainer = KDTrainer("test_load_ckpt_best")
@@ -124,7 +124,7 @@ def test_load_ckpt_best_for_student(self):

         student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)

-        self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
+        self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.student))

     def test_load_ckpt_best_for_student_with_ema(self):
         kd_trainer = KDTrainer("test_load_ckpt_best")
@@ -146,7 +146,7 @@ def test_load_ckpt_best_for_student_with_ema(self):

         student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)

-        self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
+        self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.student))

     def test_resume_kd_training(self):
         kd_trainer = KDTrainer("test_resume_training_start")
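
Reviewer note: the snippet below is a minimal usage sketch, not part of the patch. It shows how the is_model_wrapped / unwrap_model helpers introduced in src/super_gradients/training/utils/utils.py behave, and what the "module."-prefix stripping added to adaptive_load_state_dict is equivalent to. TinyNet is a hypothetical stand-in for any model; everything else comes from the diff above.

import collections

from torch import nn

from super_gradients.training.utils.utils import is_model_wrapped, unwrap_model


class TinyNet(nn.Module):
    """Hypothetical toy model used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()
wrapped = nn.DataParallel(model)

assert is_model_wrapped(wrapped) and not is_model_wrapped(model)
assert unwrap_model(wrapped) is model  # the DataParallel/DDP/WrappedModel layer is peeled off
assert unwrap_model(model) is model    # plain modules pass through unchanged

# Checkpoints saved from a wrapped model carry a "module." prefix on every key.
# adaptive_load_state_dict() now strips it automatically; done by hand it amounts to:
state_dict = wrapped.state_dict()
if all(key.startswith("module.") for key in state_dict.keys()):
    state_dict = collections.OrderedDict((key[len("module."):], value) for key, value in state_dict.items())
model.load_state_dict(state_dict)

Since the Trainer no longer wraps single-GPU models, call sites now unwrap before delegating (e.g. build_optimizer raises a ValueError when handed a wrapped model and expects build_optimizer(unwrap_model(net), ...)), which is the pattern applied throughout the diff.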