diff --git a/Makefile b/Makefile
index 6b950b02f0..6720befe96 100644
--- a/Makefile
+++ b/Makefile
@@ -8,9 +8,9 @@ yolo_nas_integration_tests:
 	python -m unittest tests/integration_tests/yolo_nas_integration_test.py
 
 recipe_accuracy_tests:
-	python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test epochs=1 batch_size=4 val_batch_size=8 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=cifar10_resnet               experiment_name=shortened_cifar10_resnet_accuracy_test   epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_yolox               experiment_name=shortened_coco2017_yolox_n_map_test      epochs=10  architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
-	python3.8 src/super_gradients/train_from_recipe.py --config-name=cityscapes_regseg48          experiment_name=shortened_cityscapes_regseg48_iou_test   epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test epochs=1 batch_size=4 val_batch_size=8 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=cifar10_resnet               experiment_name=shortened_cifar10_resnet_accuracy_test   epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=coco2017_yolox               experiment_name=shortened_coco2017_yolox_n_map_test      epochs=10  architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/train_from_recipe.py --config-name=cityscapes_regseg48          experiment_name=shortened_cityscapes_regseg48_iou_test   epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+	python src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
 	coverage run --source=super_gradients -m unittest tests/deci_core_recipe_test_suite_runner.py
diff --git a/src/super_gradients/training/kd_trainer/kd_trainer.py b/src/super_gradients/training/kd_trainer/kd_trainer.py
index 97a4c2b1db..3cce2373d2 100644
--- a/src/super_gradients/training/kd_trainer/kd_trainer.py
+++ b/src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -27,6 +27,7 @@
 from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model
 from super_gradients.training.utils.distributed_training_utils import setup_device
 from super_gradients.training.utils.ema import KDModelEMA
+from super_gradients.training.utils.utils import unwrap_model
 
 logger = get_logger(__name__)
 
@@ -211,7 +212,7 @@ def _load_checkpoint_to_model(self):
          the entire KD network following the same logic as in Trainer.
         """
         teacher_checkpoint_path = get_param(self.checkpoint_params, "teacher_checkpoint_path")
-        teacher_net = self.net.module.teacher
+        teacher_net = unwrap_model(self.net).teacher
 
         if teacher_checkpoint_path is not None:
 
@@ -271,12 +272,12 @@ def _save_best_checkpoint(self, epoch, state):
         Overrides parent best_ckpt saving to modify the state dict so that we only save the student.
         """
         if self.ema:
-            best_net = core_utils.WrappedModel(self.ema_model.ema.module.student)
+            best_net = self.ema_model.ema.student
             state.pop("ema_net")
         else:
-            best_net = core_utils.WrappedModel(self.net.module.student)
+            best_net = self.net.student
 
-        state["net"] = best_net.state_dict()
+        state["net"] = unwrap_model(best_net).state_dict()
         self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
 
     def train(
diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 1a3b38e51b..0ce5ad0d94 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -66,7 +66,7 @@
 from super_gradients.training.utils.ema import ModelEMA
 from super_gradients.training.utils.optimizer_utils import build_optimizer
 from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
-from super_gradients.training.utils.utils import fuzzy_idx_in_list
+from super_gradients.training.utils.utils import fuzzy_idx_in_list, unwrap_model
 from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.utils import random_seed
@@ -396,9 +396,6 @@ def _net_to_device(self):
             local_rank = int(device_config.device.split(":")[1])
             self.net = torch.nn.parallel.DistributedDataParallel(self.net, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
 
-        else:
-            self.net = core_utils.WrappedModel(self.net)
-
     def _train_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
         """
         train_epoch - A single epoch training procedure
@@ -601,7 +598,8 @@ def _save_checkpoint(
         metric = validation_results_dict[self.metric_to_watch]
 
         # BUILD THE state_dict
-        state = {"net": self.net.state_dict(), "acc": metric, "epoch": epoch}
+        state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch}
+
         if optimizer is not None:
             state["optimizer_state_dict"] = optimizer.state_dict()
 
@@ -609,7 +607,7 @@ def _save_checkpoint(
             state["scaler_state_dict"] = self.scaler.state_dict()
 
         if self.ema:
-            state["ema_net"] = self.ema_model.ema.state_dict()
+            state["ema_net"] = unwrap_model(self.ema_model.ema).state_dict()
 
         processing_params = self._get_preprocessing_from_valid_loader()
         if processing_params is not None:
@@ -636,7 +634,7 @@ def _save_checkpoint(
             logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))
 
         if self.training_params.average_best_models:
-            net_for_averaging = self.ema_model.ema if self.ema else self.net
+            net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
 
             state["net"] = self.model_weight_averaging.get_average_model(net_for_averaging, validation_results_dict=validation_results_dict)
             self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename, state_dict=state, global_step=epoch)
@@ -652,7 +650,7 @@ def _prep_net_for_train(self) -> None:
         self._net_to_device()
 
         # SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE
-        self.update_param_groups = hasattr(self.net.module, "update_param_groups")
+        self.update_param_groups = hasattr(unwrap_model(self.net), "update_param_groups")
 
         self.checkpoint = {}
         self.strict_load = core_utils.get_param(self.training_params, "resume_strict_load", StrictLoad.ON)
@@ -1161,7 +1159,9 @@ def forward(self, inputs, targets):
         if not self.ddp_silent_mode:
             if self.training_params.dataset_statistics:
                 dataset_statistics_logger = DatasetStatisticsTensorboardLogger(self.sg_logger)
-                dataset_statistics_logger.analyze(self.train_loader, all_classes=self.classes, title="Train-set", anchors=self.net.module.arch_params.anchors)
+                dataset_statistics_logger.analyze(
+                    self.train_loader, all_classes=self.classes, title="Train-set", anchors=unwrap_model(self.net).arch_params.anchors
+                )
                 dataset_statistics_logger.analyze(self.valid_loader, all_classes=self.classes, title="val-set")
 
         sg_trainer_utils.log_uncaught_exceptions(logger)
@@ -1175,7 +1175,7 @@ def forward(self, inputs, targets):
         if isinstance(self.training_params.optimizer, str) or (
             inspect.isclass(self.training_params.optimizer) and issubclass(self.training_params.optimizer, torch.optim.Optimizer)
         ):
-            self.optimizer = build_optimizer(net=self.net, lr=self.training_params.initial_lr, training_params=self.training_params)
+            self.optimizer = build_optimizer(net=unwrap_model(self.net), lr=self.training_params.initial_lr, training_params=self.training_params)
         elif isinstance(self.training_params.optimizer, torch.optim.Optimizer):
             self.optimizer = self.training_params.optimizer
         else:
@@ -1248,7 +1248,7 @@ def forward(self, inputs, targets):
 
         processing_params = self._get_preprocessing_from_valid_loader()
         if processing_params is not None:
-            self.net.module.set_dataset_processing_params(**processing_params)
+            unwrap_model(self.net).set_dataset_processing_params(**processing_params)
 
         try:
             # HEADERS OF THE TRAINING PROGRESS
@@ -1295,7 +1295,7 @@ def forward(self, inputs, targets):
                             num_gpus=get_world_size(),
                         )
 
-                # model switch - we replace self.net.module with the ema model for the testing and saving part
+                # model switch - we replace self.net with the ema model for the testing and saving part
                 # and then switch it back before the next training epoch
                 if self.ema:
                     self.ema_model.update_attr(self.net)
@@ -1355,7 +1355,7 @@ def forward(self, inputs, targets):
     def _get_preprocessing_from_valid_loader(self) -> Optional[dict]:
         valid_loader = self.valid_loader
 
-        if isinstance(self.net.module, HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
+        if isinstance(unwrap_model(self.net), HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
             try:
                 return valid_loader.dataset.get_dataset_preprocessing_params()
             except Exception as e:
@@ -1413,7 +1413,7 @@ def _initialize_mixed_precision(self, mixed_precision_enabled: bool):
                 def hook(module, _):
                     module.forward = MultiGPUModeAutocastWrapper(module.forward)
 
-                self.net.module.register_forward_pre_hook(hook=hook)
+                unwrap_model(self.net).register_forward_pre_hook(hook=hook)
 
             if self.load_checkpoint:
                 scaler_state_dict = core_utils.get_param(self.checkpoint, "scaler_state_dict")
@@ -1439,7 +1439,7 @@ def _validate_final_average_model(self, cleanup_snapshots_pkl_file=False):
         with wait_for_the_master(local_rank):
             average_model_sd = read_ckpt_state_dict(average_model_ckpt_path)["net"]
 
-        self.net.load_state_dict(average_model_sd)
+        unwrap_model(self.net).load_state_dict(average_model_sd)
         # testing the averaged model and save instead of best model if needed
         averaged_model_results_dict = self._validate_epoch(epoch=self.max_epochs)
 
@@ -1462,7 +1462,7 @@ def get_arch_params(self):
 
     @property
     def get_structure(self):
-        return self.net.module.structure
+        return unwrap_model(self.net).structure
 
     @property
     def get_architecture(self):
@@ -1494,7 +1494,8 @@ def _re_build_model(self, arch_params={}):
 
         if device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
             logger.warning("Warning: distributed training is not supported in re_build_model()")
-        self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids()) if device_config.multi_gpu else core_utils.WrappedModel(self.net)
+        if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
+            self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids())
 
     @property
     def get_module(self):
@@ -1635,7 +1636,7 @@ def _initialize_sg_logger_objects(self, additional_configs_to_log: Dict = None):
             if "model_name" in get_callable_param_names(sg_logger_cls.__init__):
                 if sg_logger_params.get("model_name") is None:
                     # Use the model name used in `models.get(...)` if relevant
-                    sg_logger_params["model_name"] = get_model_name(self.net.module)
+                    sg_logger_params["model_name"] = get_model_name(unwrap_model(self.net))
 
                 if sg_logger_params["model_name"] is None:
                     raise ValueError(
diff --git a/src/super_gradients/training/utils/__init__.py b/src/super_gradients/training/utils/__init__.py
index f75d9c7438..4cf763988d 100755
--- a/src/super_gradients/training/utils/__init__.py
+++ b/src/super_gradients/training/utils/__init__.py
@@ -1,7 +1,6 @@
 from super_gradients.training.utils.utils import (
     Timer,
     HpmStruct,
-    WrappedModel,
     convert_to_tensor,
     get_param,
     tensor_container_to_device,
@@ -17,7 +16,6 @@
 __all__ = [
     "Timer",
     "HpmStruct",
-    "WrappedModel",
     "convert_to_tensor",
     "get_param",
     "tensor_container_to_device",
diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py
index d3d714f87c..1f2991d4e7 100644
--- a/src/super_gradients/training/utils/callbacks/callbacks.py
+++ b/src/super_gradients/training/utils/callbacks/callbacks.py
@@ -23,7 +23,7 @@
 from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
 from super_gradients.common.environment.ddp_utils import multi_process_safe
 from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path
-
+from super_gradients.training.utils.utils import unwrap_model
 
 logger = get_logger(__name__)
 
@@ -75,7 +75,7 @@ def __init__(self, model_name: str, input_dimensions: Sequence[int], primary_bat
         self.atol = kwargs.get("atol", 1e-05)
 
     def __call__(self, context: PhaseContext):
-        model = copy.deepcopy(context.net.module)
+        model = copy.deepcopy(unwrap_model(context.net))
         model = model.cpu()
         model.eval()  # Put model into eval mode
 
@@ -204,12 +204,12 @@ def __call__(self, context: PhaseContext) -> None:
         :param context: Training phase context
         """
         try:
-            model = copy.deepcopy(context.net)
+            model = copy.deepcopy(unwrap_model(context.net))
             model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
             model_state_dict = torch.load(model_state_dict_path)["net"]
             model.load_state_dict(state_dict=model_state_dict)
 
-            model = model.module.cpu()
+            model = model.cpu()
             if hasattr(model, "prep_model_for_conversion"):
                 model.prep_model_for_conversion(input_size=self.input_dimensions)
 
@@ -267,7 +267,9 @@ def perform_scheduling(self, context: PhaseContext):
 
     def update_lr(self, optimizer, epoch, batch_idx=None):
         if self.update_param_groups:
-            param_groups = self.net.module.update_param_groups(optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len)
+            param_groups = unwrap_model(self.net).update_param_groups(
+                optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
+            )
             optimizer.param_groups = param_groups
         else:
             # UPDATE THE OPTIMIZERS PARAMETER
@@ -373,7 +375,9 @@ def update_lr(self, optimizer, epoch, batch_idx=None):
         :return:
         """
         if self.update_param_groups:
-            param_groups = self.net.module.update_param_groups(optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len)
+            param_groups = unwrap_model(self.net).update_param_groups(
+                optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
+            )
             optimizer.param_groups = param_groups
         else:
             # UPDATE THE OPTIMIZERS PARAMETER
diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index 8dc8ebe35c..904a65d5de 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -14,6 +14,7 @@
 from super_gradients.module_interfaces import HasPredict
 from super_gradients.training.pretrained_models import MODEL_URLS
 from super_gradients.training.utils.distributed_training_utils import get_local_rank, wait_for_the_master
+from super_gradients.training.utils.utils import unwrap_model
 
 try:
     from torch.hub import download_url_to_file, load_state_dict_from_url
@@ -54,6 +55,13 @@ def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: Uni
     :return:
     """
     state_dict = state_dict["net"] if "net" in state_dict else state_dict
+
+    # Backward-compatibility fix for checkpoints that were saved with a DataParallel/DistributedDataParallel wrapper
+    # and therefore contain the "module." prefix in all keys.
+    # If every key starts with "module.", we strip it.
+    if all([key.startswith("module.") for key in state_dict.keys()]):
+        state_dict = collections.OrderedDict([(key[7:], value) for key, value in state_dict.items()])
+
     try:
         strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
         net.load_state_dict(state_dict, strict=strict_bool)
@@ -217,6 +225,8 @@ def load_checkpoint_to_model(
     if isinstance(strict, str):
         strict = StrictLoad(strict)
 
+    net = unwrap_model(net)
+
     if load_backbone and not hasattr(net, "backbone"):
         raise ValueError("No backbone attribute in net - Can't load backbone weights")
 
@@ -239,7 +249,7 @@ def load_checkpoint_to_model(
     message_model = "model" if not load_backbone else "model's backbone"
     logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)
 
-    if (isinstance(net, HasPredict) or (hasattr(net, "module") and isinstance(net.module, HasPredict))) and load_processing_params:
+    if isinstance(net, HasPredict) and load_processing_params:
         if "processing_params" not in checkpoint.keys():
             raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
         try:
@@ -275,7 +285,7 @@ def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):
 
     if (
         ckpt_val.shape != model_val.shape
-        and ckpt_key == "module._backbone._modules_list.0.conv.conv.weight"
+        and (ckpt_key == "module._backbone._modules_list.0.conv.conv.weight" or ckpt_key == "_backbone._modules_list.0.conv.conv.weight")
         and model_key == "_backbone._modules_list.0.conv.weight"
     ):
         model_val.data[:, :, ::2, ::2] = ckpt_val.data[:, :3]
diff --git a/src/super_gradients/training/utils/ema.py b/src/super_gradients/training/utils/ema.py
index f57dd73494..1710f577cd 100755
--- a/src/super_gradients/training/utils/ema.py
+++ b/src/super_gradients/training/utils/ema.py
@@ -3,11 +3,11 @@
 from typing import Union
 
 import torch
+from super_gradients.training.utils.utils import unwrap_model
 from torch import nn
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
-from super_gradients.training import utils as core_utils
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.kd_modules.kd_module import KDModule
 from super_gradients.training.utils.ema_decay_schedules import IDecayFunction, EMA_DECAY_FUNCTIONS
@@ -34,7 +34,7 @@ class ModelEMA:
     GPU assignment and distributed training wrappers.
     """
 
-    def __init__(self, model, decay: float, decay_function: IDecayFunction):
+    def __init__(self, model: nn.Module, decay: float, decay_function: IDecayFunction):
         """
         Init the EMA
         :param model: Union[SgModule, nn.Module], the training model to construct the EMA model by
@@ -46,6 +46,7 @@ def __init__(self, model, decay: float, decay_function: IDecayFunction):
                      its final value. beta=15 is ~40% of the training process.
         """
         # Create EMA
+        model = unwrap_model(model)
         self.ema = deepcopy(model)
         self.ema.eval()
         self.decay = decay
@@ -57,14 +58,14 @@ def __init__(self, model, decay: float, decay_function: IDecayFunction):
         get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule
         all non-private (not starting with '_') attributes will be updated (and only them).
         """
-        if isinstance(model.module, SgModule):
-            self.include_attributes = model.module.get_include_attributes()
-            self.exclude_attributes = model.module.get_exclude_attributes()
+        if isinstance(model, SgModule):
+            self.include_attributes = model.get_include_attributes()
+            self.exclude_attributes = model.get_exclude_attributes()
         else:
             warnings.warn("Warning: EMA should be used with SgModule instance. All attributes of the model will be " "included in EMA")
             self.include_attributes = []
             self.exclude_attributes = []
-        for p in self.ema.module.parameters():
+        for p in self.ema.parameters():
             p.requires_grad_(False)
 
     @classmethod
@@ -131,10 +132,11 @@ def update(self, model, step: int, total_steps: int):
         :param total_steps: Total training steps
         """
         # Update EMA parameters
+        model = unwrap_model(model)
         with torch.no_grad():
             decay = self.decay_function(self.decay, step, total_steps)
 
-            for ema_v, model_v in zip(self.ema.module.state_dict().values(), model.state_dict().values()):
+            for ema_v, model_v in zip(self.ema.state_dict().values(), model.state_dict().values()):
                 if ema_v.dtype.is_floating_point:
                     ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v.detach())
 
@@ -147,7 +149,7 @@ def update_attr(self, model):
         attributes will be updated (and only them).
         :param model: the source model
         """
-        copy_attr(self.ema.module, model.module, self.include_attributes, self.exclude_attributes)
+        copy_attr(self.ema, unwrap_model(model), self.include_attributes, self.exclude_attributes)
 
 
 class KDModelEMA(ModelEMA):
@@ -172,15 +174,13 @@ def __init__(self, kd_model: KDModule, decay: float, decay_function: IDecayFunct
                      its final value. beta=15 is ~40% of the training process.
         """
         # Only work on the student (we don't want to update and to have a duplicate of the teacher)
-        super().__init__(model=core_utils.WrappedModel(kd_model.module.student), decay=decay, decay_function=decay_function)
+        super().__init__(model=kd_model.student, decay=decay, decay_function=decay_function)
 
         # Overwrite current ema attribute with combination of the student model EMA (current self.ema)
         # with already the instantiated teacher, to have the final KD EMA
-        self.ema = core_utils.WrappedModel(
-            KDModule(
-                arch_params=kd_model.module.arch_params,
-                student=self.ema.module,
-                teacher=kd_model.module.teacher,
-                run_teacher_on_eval=kd_model.module.run_teacher_on_eval,
-            )
+        self.ema = KDModule(
+            arch_params=kd_model.arch_params,
+            student=self.ema,
+            teacher=kd_model.teacher,
+            run_teacher_on_eval=kd_model.run_teacher_on_eval,
         )
diff --git a/src/super_gradients/training/utils/optimizer_utils.py b/src/super_gradients/training/utils/optimizer_utils.py
index 7844f67a9f..5f47a7f72e 100755
--- a/src/super_gradients/training/utils/optimizer_utils.py
+++ b/src/super_gradients/training/utils/optimizer_utils.py
@@ -1,7 +1,5 @@
-import torch.optim as optim
 import torch.nn as nn
-from torch.nn.modules.batchnorm import _BatchNorm
-from torch.nn.modules.conv import _ConvNd
+import torch.optim as optim
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.factories.optimizers_type_factory import OptimizersTypeFactory
 from super_gradients.training.params import (
@@ -12,6 +10,9 @@
 )
 from super_gradients.training.utils import get_param
 from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF
+from super_gradients.training.utils.utils import is_model_wrapped
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.conv import _ConvNd
 
 logger = get_logger(__name__)
 
@@ -86,6 +87,8 @@ def build_optimizer(net: nn.Module, lr: float, training_params) -> optim.Optimiz
         :param lr: initial learning rate
         :param training_params: training_parameters
     """
+    if is_model_wrapped(net):
+        raise ValueError("Argument net for build_optimizer must be an unwrapped model. " "Please use build_optimizer(unwrap_model(net), ...).")
     if isinstance(training_params.optimizer, str):
         optimizer_cls = OptimizersTypeFactory().get(training_params.optimizer)
     else:
@@ -96,14 +99,14 @@ def build_optimizer(net: nn.Module, lr: float, training_params) -> optim.Optimiz
 
     weight_decay = get_param(training_params.optimizer_params, "weight_decay", 0.0)
     # OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT
-    if hasattr(net.module, "initialize_param_groups"):
+    if hasattr(net, "initialize_param_groups"):
         # INITIALIZE_PARAM_GROUPS MUST RETURN A LIST OF DICTS WITH 'named_params' AND OPTIMIZER's ATTRIBUTES PER GROUP
-        net_named_params = net.module.initialize_param_groups(lr, training_params)
+        net_named_params = net.initialize_param_groups(lr, training_params)
     else:
         net_named_params = [{"named_params": net.named_parameters()}]
 
     if training_params.zero_weight_decay_on_bias_and_bn:
-        optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(net.module, net_named_params, weight_decay)
+        optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(net, net_named_params, weight_decay)
 
     else:
         # Overwrite groups to include params instead of named params
diff --git a/src/super_gradients/training/utils/utils.py b/src/super_gradients/training/utils/utils.py
index 6a6f05a6be..a8caf0d6cc 100755
--- a/src/super_gradients/training/utils/utils.py
+++ b/src/super_gradients/training/utils/utils.py
@@ -6,6 +6,7 @@
 import time
 import inspect
 import typing
+import warnings
 from functools import lru_cache, wraps
 from importlib import import_module
 from itertools import islice
@@ -13,6 +14,7 @@
 from pathlib import Path
 from typing import Mapping, Optional, Tuple, Union, List, Dict, Any, Iterable
 from zipfile import ZipFile
+from torch.nn.parallel import DistributedDataParallel
 
 import numpy as np
 import torch
@@ -80,12 +82,36 @@ def validate(self):
 class WrappedModel(nn.Module):
     def __init__(self, module):
         super(WrappedModel, self).__init__()
+        warnings.warn(
+            "WrappedModel is deprecated and will be removed in the next major release of SuperGradients. "
+            "You no longer need to wrap your model; simply remove the wrapper and everything will work as expected.",
+            DeprecationWarning,
+        )
+
         self.module = module  # that I actually define.
 
     def forward(self, x):
         return self.module(x)
 
 
+def is_model_wrapped(model: nn.Module) -> bool:
+    return isinstance(model, (nn.DataParallel, DistributedDataParallel, WrappedModel))
+
+
+def unwrap_model(model: Union[nn.Module, nn.DataParallel, DistributedDataParallel]) -> nn.Module:
+    """
+    Get the underlying model from a model wrapper (DataParallel, DistributedDataParallel or WrappedModel).
+
+    :param model: A model, possibly wrapped in DataParallel, DistributedDataParallel or the deprecated WrappedModel.
+    :return:      The unwrapped nn.Module.
+    """
+    if is_model_wrapped(model):
+        return model.module
+    elif isinstance(model, nn.Module):
+        return model
+    raise ValueError(f"Unknown model type: {type(model)}")
+
+
 def arch_params_deprecated(func):
     """
     Since initialization of arch_params is deprecated and will be removed, this decorator will be used to wrap the _init_
diff --git a/src/super_gradients/training/utils/weight_averaging_utils.py b/src/super_gradients/training/utils/weight_averaging_utils.py
index ab04fa5bc4..ecbccadb7c 100755
--- a/src/super_gradients/training/utils/weight_averaging_utils.py
+++ b/src/super_gradients/training/utils/weight_averaging_utils.py
@@ -2,7 +2,7 @@
 import torch
 import numpy as np
 from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict
-from super_gradients.training.utils.utils import move_state_dict_to_device
+from super_gradients.training.utils.utils import move_state_dict_to_device, unwrap_model
 
 
 class ModelWeightAveraging:
@@ -60,7 +60,7 @@ def update_snapshots_dict(self, model, validation_results_dict):
         require_update, update_ind = self._is_better(averaging_snapshots_dict, validation_results_dict)
         if require_update:
             # moving state dict to cpu
-            new_sd = model.state_dict()
+            new_sd = unwrap_model(model).state_dict()
             new_sd = move_state_dict_to_device(new_sd, "cpu")
 
             averaging_snapshots_dict["snapshot" + str(update_ind)] = new_sd
diff --git a/tests/unit_tests/kd_ema_test.py b/tests/unit_tests/kd_ema_test.py
index 7dc4cf34ab..1f59084fe7 100644
--- a/tests/unit_tests/kd_ema_test.py
+++ b/tests/unit_tests/kd_ema_test.py
@@ -5,7 +5,7 @@
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.kd_trainer import KDTrainer
 import torch
-from super_gradients.training.utils.utils import check_models_have_same_weights
+from super_gradients.training.utils.utils import check_models_have_same_weights, unwrap_model
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.losses.kd_losses import KDLogitsLoss
 from super_gradients.common.object_names import Models
@@ -52,8 +52,8 @@ def test_teacher_ema_not_duplicated(self):
             valid_loader=classification_test_dataloader(),
         )
 
-        self.assertTrue(kd_model.ema_model.ema.module.teacher is kd_model.net.module.teacher)
-        self.assertTrue(kd_model.ema_model.ema.module.student is not kd_model.net.module.student)
+        self.assertTrue(unwrap_model(kd_model.ema_model.ema).teacher is unwrap_model(kd_model.net).teacher)
+        self.assertTrue(unwrap_model(kd_model.ema_model.ema).student is not unwrap_model(kd_model.net).student)
 
     def test_kd_ckpt_reload_net(self):
         """Check that the KD trainer load correctly from checkpoint when "load_ema_as_net=False"."""
@@ -100,10 +100,10 @@ def test_kd_ckpt_reload_net(self):
         self.assertTrue(not check_models_have_same_weights(reloaded_net, ema_model))
 
         # loaded student ema == loaded  student net (since load_ema_as_net = False)
-        self.assertTrue(not check_models_have_same_weights(reloaded_ema_model.module.student, reloaded_net.module.student))
+        self.assertTrue(not check_models_have_same_weights(reloaded_ema_model.student, reloaded_net.student))
 
         # loaded teacher ema == loaded teacher net (teacher always loads ema)
-        self.assertTrue(check_models_have_same_weights(reloaded_ema_model.module.teacher, reloaded_net.module.teacher))
+        self.assertTrue(check_models_have_same_weights(reloaded_ema_model.teacher, reloaded_net.teacher))
 
 
 if __name__ == "__main__":
diff --git a/tests/unit_tests/kd_trainer_test.py b/tests/unit_tests/kd_trainer_test.py
index 2b178de217..3b866e7b2a 100644
--- a/tests/unit_tests/kd_trainer_test.py
+++ b/tests/unit_tests/kd_trainer_test.py
@@ -81,10 +81,10 @@ def test_train_kd_module_external_models(self):
         )
 
         # TEACHER WEIGHT'S SHOULD REMAIN THE SAME
-        self.assertTrue(check_models_have_same_weights(teacher_model, sg_model.net.module.teacher))
+        self.assertTrue(check_models_have_same_weights(teacher_model, sg_model.net.teacher))
 
         # STUDENT WEIGHT'S SHOULD NOT REMAIN THE SAME
-        self.assertFalse(check_models_have_same_weights(student_model, sg_model.net.module.student))
+        self.assertFalse(check_models_have_same_weights(student_model, sg_model.net.student))
 
     def test_train_model_with_input_adapter(self):
         kd_trainer = KDTrainer("train_kd_module_with_with_input_adapter")
@@ -105,7 +105,7 @@ def test_train_model_with_input_adapter(self):
             valid_loader=classification_test_dataloader(),
         )
 
-        self.assertEqual(kd_trainer.net.module.teacher_input_adapter, adapter)
+        self.assertEqual(kd_trainer.net.teacher_input_adapter, adapter)
 
     def test_load_ckpt_best_for_student(self):
         kd_trainer = KDTrainer("test_load_ckpt_best")
@@ -124,7 +124,7 @@ def test_load_ckpt_best_for_student(self):
 
         student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
 
-        self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.module.student))
+        self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.net.student))
 
     def test_load_ckpt_best_for_student_with_ema(self):
         kd_trainer = KDTrainer("test_load_ckpt_best")
@@ -146,7 +146,7 @@ def test_load_ckpt_best_for_student_with_ema(self):
 
         student_reloaded = models.get(Models.RESNET18, arch_params={"num_classes": 5}, checkpoint_path=best_student_ckpt)
 
-        self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.module.student))
+        self.assertTrue(check_models_have_same_weights(student_reloaded, kd_trainer.ema_model.ema.student))
 
     def test_resume_kd_training(self):
         kd_trainer = KDTrainer("test_resume_training_start")