diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py
index 427572ba60..767eb65e0b 100644
--- a/monai/networks/nets/mednext.py
+++ b/monai/networks/nets/mednext.py
@@ -19,6 +19,7 @@
 
 import torch
 import torch.nn as nn
+from torch.nn.functional import interpolate
 
 from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
 
@@ -57,7 +58,22 @@ class MedNeXt(nn.Module):
         decoder_expansion_ratio: expansion ratio for decoder blocks. Defaults to 2.
         bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2.
         kernel_size: kernel size for convolutions. Defaults to 7.
-        deep_supervision: whether to use deep supervision. Defaults to False.
+        deep_supervision: whether to use deep supervision. Defaults to ``False``.
+            If ``True``, in training mode, the forward function outputs not only the final feature map
+            (from the `out_0` block) but also the feature maps produced by the intermediate upsampling layers.
+            To unify the return type, all intermediate feature maps are interpolated to the same spatial size
+            as the final feature map and stacked with it (along a new dimension at axis 1) into a single tensor.
+            For instance, if there are two intermediate feature maps with shapes (1, 2, 16, 12) and
+            (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then both intermediate feature
+            maps are interpolated to (1, 2, 32, 24), and the stacked tensor has the shape (1, 3, 2, 32, 24).
+            When calculating the loss, you can use ``torch.unbind`` on axis 1 to recover the individual feature
+            maps, compute a loss for each against the ground truth, and take a weighted average as the final loss.
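+
+            A minimal sketch of one way to combine the losses (``loss_fn``, ``inputs`` and
+            ``target`` below are illustrative placeholders, not part of this class)::
+
+                preds = torch.unbind(net(inputs), dim=1)  # net in training mode
+                loss = sum(0.5**i * loss_fn(p, target) for i, p in enumerate(preds))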
         use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False.
         blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2].
         blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2.
@@ -260,7 +270,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]:
 
         # Return output(s)
         if self.do_ds and self.training:
-            return (x, *ds_outputs[::-1])
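+            # interpolate each deep supervision output to the final output's spatial size,
+            # then stack all outputs along a new dimension at axis 1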
+            out_all = [x]
+            for feature_map in ds_outputs[::-1]:
+                out_all.append(interpolate(feature_map, size=x.shape[2:]))
+            return torch.stack(out_all, dim=1)
         else:
             return x
 
diff --git a/tests/test_mednext.py b/tests/test_mednext.py
index b4ba4f9939..4c715d9282 100644
--- a/tests/test_mednext.py
+++ b/tests/test_mednext.py
@@ -75,8 +75,12 @@ def test_shape(self, input_param, input_shape, expected_shape):
         with eval_mode(net):
             result = net(torch.randn(input_shape).to(device))
             if input_param["deep_supervision"] and net.training:
-                assert isinstance(result, tuple)
-                self.assertEqual(result[0].shape, expected_shape, msg=str(input_param))
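+                # deep supervision returns a single stacked tensor; unbind along
+                # axis 1 to recover the per-scale outputs and check each shape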
+                assert isinstance(result, torch.Tensor)
+                result = torch.unbind(result, dim=1)
+                for r in result:
+                    self.assertEqual(r.shape, expected_shape, msg=str(input_param))
             else:
                 self.assertEqual(result.shape, expected_shape, msg=str(input_param))
 
@@ -87,8 +89,10 @@ def test_shape2(self, input_param, input_shape, expected_shape):
         net.train()
         result = net(torch.randn(input_shape).to(device))
         if input_param["deep_supervision"]:
-            assert isinstance(result, tuple)
-            self.assertEqual(result[0].shape, expected_shape, msg=str(input_param))
+            assert isinstance(result, torch.Tensor)
+            result = torch.unbind(result, dim=1)
+            for r in result:
+                self.assertEqual(r.shape, expected_shape, msg=str(input_param))
         else:
             assert isinstance(result, torch.Tensor)
             self.assertEqual(result.shape, expected_shape, msg=str(input_param))