diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py
index a713da9a70..b555df9e94 100644
--- a/mmengine/_strategy/base.py
+++ b/mmengine/_strategy/base.py
@@ -322,7 +322,8 @@ def compile_model(
         Returns:
             nn.Module: Compiled model.
         """
-        if isinstance(compile, bool) and not compile:
+        if (isinstance(compile, bool) and not compile) or \
+                (isinstance(compile, dict) and compile.get('disable', False)):
             return model
 
         assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), (
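
Note on the guard above: `compile` may be a plain bool or a dict of `torch.compile` options, and compilation is skipped for `compile=False` or for a dict with `disable=True`. A minimal standalone sketch of that dispatch (the `maybe_compile` helper and the `nn.Linear` stand-in are illustrative, not part of BaseStrategy):

    import torch
    import torch.nn as nn

    def maybe_compile(model: nn.Module, compile):
        """Return ``model`` unchanged unless compilation is requested."""
        # Skip for ``compile=False`` or a dict that sets ``disable=True``.
        if isinstance(compile, bool) and not compile:
            return model
        if isinstance(compile, dict) and compile.get('disable', False):
            return model
        # ``True`` selects default options; a dict is forwarded verbatim to
        # ``torch.compile`` (requires PyTorch >= 2.0).
        compile_kwargs = compile if isinstance(compile, dict) else {}
        return torch.compile(model, **compile_kwargs)

    compiled = maybe_compile(nn.Linear(4, 4), compile=dict(mode='default'))
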
diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py
index 0788fafdab..124dfd7c57 100644
--- a/mmengine/_strategy/fsdp.py
+++ b/mmengine/_strategy/fsdp.py
@@ -408,7 +408,7 @@ def load_optim_state_dict(self, state_dict: dict) -> None:
                 ``optimizer.state_dict()``
         """
         optim_state_dict = FSDP.optim_state_dict_to_load(
-            state_dict, self.model, self.optim_wrapper.optimizer)
+            self.model, self.optim_wrapper.optimizer, state_dict)
         self.optim_wrapper.load_state_dict(optim_state_dict)
 
     def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
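
The reordered call above follows the `FSDP.optim_state_dict_to_load` signature of newer PyTorch releases (roughly 2.1 onward), which take the model and optimizer before the saved optimizer state dict; 2.0.x expected the state dict first. A hedged sketch of the restore step, assuming `model` is already FSDP-wrapped and `state_dict` was produced by `FSDP.optim_state_dict`:

    from typing import Any, Dict

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.optim import Optimizer

    def load_fsdp_optim_state(model: nn.Module, optimizer: Optimizer,
                              state_dict: Dict[str, Any]) -> None:
        """Restore optimizer state saved from an FSDP-wrapped model."""
        # Newer order: (model, optim, optim_state_dict); the result is the
        # rank-local flattened state dict expected by load_state_dict.
        flattened = FSDP.optim_state_dict_to_load(
            model, optimizer, state_dict)
        optimizer.load_state_dict(flattened)
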
diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py
index 3cfe0b14a8..276e6fe218 100644
--- a/mmengine/model/base_module.py
+++ b/mmengine/model/base_module.py
@@ -65,7 +65,6 @@ def is_init(self, value):
 
     def init_weights(self):
         """Initialize the weights."""
-
         is_top_level_module = False
         # check if it is top-level module
         if not hasattr(self, '_params_init_info'):
diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index 4f3323f2cc..60200924b5 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
+from functools import partial
 from typing import Union
 
 import torch
@@ -17,7 +18,8 @@
 elif is_mlu_available():
     from torch.mlu.amp import GradScaler
 else:
-    from torch.cuda.amp import GradScaler
+    from torch.amp import GradScaler as amp_GradScaler
+    GradScaler = partial(amp_GradScaler, device='cuda')
 
 
 @OPTIM_WRAPPERS.register_module()
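
Context for the scaler change: recent PyTorch releases provide a device-agnostic `torch.amp.GradScaler` and deprecate `torch.cuda.amp.GradScaler`, so binding `device='cuda'` with `functools.partial` keeps the wrapper's existing `GradScaler(...)` call sites unchanged. A small self-contained sketch of the same pattern (the `enabled` flag is only there so the demo also runs on CPU-only machines):

    from functools import partial

    import torch
    from torch.amp import GradScaler as amp_GradScaler

    # Bind the device once; callers keep passing the usual keyword
    # arguments (enabled, init_scale, growth_interval, ...).
    GradScaler = partial(amp_GradScaler, device='cuda')

    scaler = GradScaler(enabled=torch.cuda.is_available())
    print(type(scaler).__name__)
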
diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index 7b4090ba7a..af98043b7f 100644
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -8,7 +8,9 @@
 
 from mmengine.config import Config, ConfigDict
 from mmengine.device import is_npu_available, is_npu_support_full_precision
+from mmengine.logging import print_log
 from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
+from .default_constructor import DefaultOptimWrapperConstructor
 from .optimizer_wrapper import OptimWrapper
 
 
@@ -170,7 +172,10 @@ def register_transformers_optimizers():
     except ImportError:
         pass
     else:
-        OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
+        try:
+            OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
+        except KeyError:
+            print_log('Adafactor is already registered; skipping.')
         transformer_optimizers.append('Adafactor')
     return transformer_optimizers
 
@@ -196,8 +201,11 @@ def build_optim_wrapper(model: nn.Module,
         OptimWrapper: The built optimizer wrapper.
     """
     optim_wrapper_cfg = copy.deepcopy(cfg)
-    constructor_type = optim_wrapper_cfg.pop('constructor',
-                                             'DefaultOptimWrapperConstructor')
+    constructor_cfg = optim_wrapper_cfg.pop('constructor', None)
+    if constructor_cfg is None:
+        constructor_cfg = dict(type=DefaultOptimWrapperConstructor)
+    elif isinstance(constructor_cfg, str):
+        constructor_cfg = dict(type=constructor_cfg)
     paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
 
     # Since the current generation of NPU(Ascend 910) only supports
@@ -205,11 +213,13 @@ def build_optim_wrapper(model: nn.Module,
     # to make the training normal
     if is_npu_available() and not is_npu_support_full_precision():
         optim_wrapper_cfg['type'] = 'AmpOptimWrapper'
+
+    constructor_cfg.update(
+        dict(
+            optim_wrapper_cfg=optim_wrapper_cfg,
+            paramwise_cfg=paramwise_cfg))
 
-    optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
-        dict(
-            type=constructor_type,
-            optim_wrapper_cfg=optim_wrapper_cfg,
-            paramwise_cfg=paramwise_cfg))
+    optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
+        constructor_cfg)
     optim_wrapper = optim_wrapper_constructor(model)
     return optim_wrapper
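
With the constructor handling above, `constructor` in an optim wrapper config may be omitted (falling back to `DefaultOptimWrapperConstructor`), given as a registered name, or given as a registry-style dict. A hedged usage sketch; the linear model and SGD settings are placeholders:

    import torch.nn as nn

    from mmengine.optim import build_optim_wrapper

    model = nn.Linear(4, 2)
    # Equivalent spellings under this patch: omit ``constructor``, pass
    # 'DefaultOptimWrapperConstructor', or pass the dict form below.
    optim_wrapper = build_optim_wrapper(
        model,
        dict(
            type='OptimWrapper',
            optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
            constructor=dict(type='DefaultOptimWrapperConstructor')))
    print(type(optim_wrapper.optimizer).__name__)
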
diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py
index 2bf5f50f7c..fa0a1eb520 100644
--- a/mmengine/runner/checkpoint.py
+++ b/mmengine/runner/checkpoint.py
@@ -344,7 +344,8 @@ def load_from_local(filename, map_location):
     filename = osp.expanduser(filename)
     if not osp.isfile(filename):
         raise FileNotFoundError(f'{filename} can not be found.')
-    checkpoint = torch.load(filename, map_location=map_location)
+    checkpoint = torch.load(
+        filename, map_location=map_location, weights_only=False)
     return checkpoint
 
 
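
Background for the `weights_only=False` argument: from PyTorch 2.6 the default of `torch.load` switched to `weights_only=True`, which refuses to unpickle arbitrary Python objects that mmengine checkpoints can contain (configs, message-hub buffers, ...). Passing `weights_only=False` keeps the old behaviour for local files the user already trusts. A small illustrative snippet (the checkpoint path is made up):

    import torch

    # A checkpoint mixing tensors with plain Python metadata, roughly like
    # the dictionaries mmengine saves.
    torch.save({'state_dict': {'w': torch.zeros(2)}, 'meta': {'epoch': 3}},
               'demo_ckpt.pth')

    # Explicitly opt in to full unpickling for this trusted local file,
    # independent of the installed PyTorch's torch.load default.
    ckpt = torch.load('demo_ckpt.pth', map_location='cpu', weights_only=False)
    print(ckpt['meta']['epoch'])
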
diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index 5a678db7b9..065bf9243c 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -377,7 +377,6 @@ def run(self) -> dict:
         self.val_loss.clear()
         for idx, data_batch in enumerate(self.dataloader):
             self.run_iter(idx, data_batch)
-
         # compute metrics
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
 
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 7d1f655aad..9bbbcaedce 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import inspect
 import logging
 import os
 import os.path as osp
@@ -902,8 +903,18 @@ def wrap_model(
                 find_unused_parameters=find_unused_parameters)
         else:
             model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')
-            model_wrapper_type = MODEL_WRAPPERS.get(
-                model_wrapper_cfg.get('type'))  # type: ignore
+            model_wrapper_type = model_wrapper_cfg.get('type')
+            if isinstance(model_wrapper_type, str):
+                model_wrapper_type = MODEL_WRAPPERS.get(
+                    model_wrapper_type)  # type: ignore
+            if not inspect.isclass(model_wrapper_type):
+                raise KeyError(
+                    f'{model_wrapper_cfg.get("type")} is not in the '
+                    'MODEL_WRAPPERS registry. Please check whether its '
+                    'value is correct or it was registered as expected. '
+                    'More details can be found at '
+                    'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'  # noqa: E501
+                )
             default_args: dict = dict()
             if issubclass(
                     model_wrapper_type,  # type: ignore
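
The branch above lets `model_wrapper_cfg['type']` be either a name registered in MODEL_WRAPPERS or a class object passed directly, and raises a descriptive KeyError for anything else. A hedged sketch of the same resolution logic in isolation (`resolve_wrapper_type` is illustrative, not a Runner method):

    import inspect

    from mmengine.model import MMDistributedDataParallel
    from mmengine.registry import MODEL_WRAPPERS

    def resolve_wrapper_type(wrapper_type):
        """Accept a registered name or a class; return the wrapper class."""
        if isinstance(wrapper_type, str):
            wrapper_type = MODEL_WRAPPERS.get(wrapper_type)
        if not inspect.isclass(wrapper_type):
            raise KeyError(f'{wrapper_type} is not a usable model wrapper')
        return wrapper_type

    # Both spellings resolve to the same class.
    print(resolve_wrapper_type(MMDistributedDataParallel).__name__)
    print(resolve_wrapper_type('MMDistributedDataParallel').__name__)
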
diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py
index b752ec85a7..a5bf7d88e7 100644
--- a/mmengine/visualization/vis_backend.py
+++ b/mmengine/visualization/vis_backend.py
@@ -604,7 +604,8 @@ def add_scalar(self,
                       (int, float, torch.Tensor, np.ndarray, np.number)):
             self._tensorboard.add_scalar(name, value, step)
         else:
-            warnings.warn(f'Got {type(value)}, but numpy array, torch tensor, '
+            warnings.warn(f'Got type {type(value)} with name {name}, '
+                          'but numpy array, torch tensor, '
                           f'int or float are expected. skip it!')
 
     @force_init_env