diff --git a/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/mmseg/core/seg/sampler/ohem_pixel_sampler.py
index 72ba941f03..833a28768c 100644
--- a/mmseg/core/seg/sampler/ohem_pixel_sampler.py
+++ b/mmseg/core/seg/sampler/ohem_pixel_sampler.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 from ..builder import PIXEL_SAMPLERS
@@ -62,14 +63,19 @@ def sample(self, seg_logit, seg_label):
                 threshold = max(min_threshold, self.thresh)
                 valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
             else:
+                if not isinstance(self.context.loss_decode, nn.ModuleList):
+                    losses_decode = [self.context.loss_decode]
+                else:
+                    losses_decode = self.context.loss_decode
                 losses = 0.0
-                for loss_module in self.context.loss_decode:
+                for loss_module in losses_decode:
                     losses += loss_module(
                         seg_logit,
                         seg_label,
                         weight=None,
                         ignore_index=self.context.ignore_index,
                         reduction_override='none')
+
                 # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
                 _, sort_indices = losses[valid_mask].sort(descending=True)
                 valid_seg_weight[sort_indices[:batch_kept]] = 1.
diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py
index c36555eaf2..1443a81da2 100644
--- a/mmseg/models/decode_heads/decode_head.py
+++ b/mmseg/models/decode_heads/decode_head.py
@@ -83,11 +83,11 @@ def __init__(self,
 
         self.ignore_index = ignore_index
         self.align_corners = align_corners
-        self.loss_decode = nn.ModuleList()
 
         if isinstance(loss_decode, dict):
-            self.loss_decode.append(build_loss(loss_decode))
+            self.loss_decode = build_loss(loss_decode)
         elif isinstance(loss_decode, (list, tuple)):
+            self.loss_decode = nn.ModuleList()
             for loss in loss_decode:
                 self.loss_decode.append(build_loss(loss))
         else:
@@ -242,7 +242,12 @@ def losses(self, seg_logit, seg_label):
         else:
             seg_weight = None
         seg_label = seg_label.squeeze(1)
-        for loss_decode in self.loss_decode:
+
+        if not isinstance(self.loss_decode, nn.ModuleList):
+            losses_decode = [self.loss_decode]
+        else:
+            losses_decode = self.loss_decode
+        for loss_decode in losses_decode:
             if loss_decode.loss_name not in loss:
                 loss[loss_decode.loss_name] = loss_decode(
                     seg_logit,
diff --git a/mmseg/models/decode_heads/point_head.py b/mmseg/models/decode_heads/point_head.py
index 56dfd4ed8b..7276218053 100644
--- a/mmseg/models/decode_heads/point_head.py
+++ b/mmseg/models/decode_heads/point_head.py
@@ -249,9 +249,14 @@ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
     def losses(self, point_logits, point_label):
         """Compute segmentation loss."""
         loss = dict()
-        for loss_module in self.loss_decode:
+        if not isinstance(self.loss_decode, nn.ModuleList):
+            losses_decode = [self.loss_decode]
+        else:
+            losses_decode = self.loss_decode
+        for loss_module in losses_decode:
             loss['point' + loss_module.loss_name] = loss_module(
                 point_logits, point_label, ignore_index=self.ignore_index)
+
         loss['acc_point'] = accuracy(point_logits, point_label)
         return loss
 
diff --git a/tests/test_models/test_heads/test_point_head.py b/tests/test_models/test_heads/test_point_head.py
index 6c5ea65768..142ab16c6c 100644
--- a/tests/test_models/test_heads/test_point_head.py
+++ b/tests/test_models/test_heads/test_point_head.py
@@ -21,3 +21,41 @@ def test_point_head():
         subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
     output = point_head.forward_test(inputs, prev_output, None, test_cfg)
     assert output.shape == (1, point_head.num_classes, 180, 180)
+
+    # test multiple losses case
+    inputs = [torch.randn(1, 32, 45, 45)]
+    point_head_multiple_losses = PointHead(
+        in_channels=[32],
+        in_index=[0],
+        channels=16,
+        num_classes=19,
+        loss_decode=[
+            dict(type='CrossEntropyLoss', loss_name='loss_1'),
+            dict(type='CrossEntropyLoss', loss_name='loss_2')
+        ])
+    assert len(point_head_multiple_losses.fcs) == 3
+    fcn_head_multiple_losses = FCNHead(
+        in_channels=32,
+        channels=16,
+        num_classes=19,
+        loss_decode=[
+            dict(type='CrossEntropyLoss', loss_name='loss_1'),
+            dict(type='CrossEntropyLoss', loss_name='loss_2')
+        ])
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(point_head_multiple_losses, inputs)
+        head, inputs = to_cuda(fcn_head_multiple_losses, inputs)
+    prev_output = fcn_head_multiple_losses(inputs)
+    test_cfg = ConfigDict(
+        subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
+    output = point_head_multiple_losses.forward_test(inputs, prev_output, None,
+                                                     test_cfg)
+    assert output.shape == (1, point_head_multiple_losses.num_classes, 180,
+                            180)
+
+    fake_label = torch.ones([1, 180, 180], dtype=torch.long)
+
+    if torch.cuda.is_available():
+        fake_label = fake_label.cuda()
+    loss = point_head_multiple_losses.losses(output, fake_label)
+    assert 'pointloss_1' in loss
+    assert 'pointloss_2' in loss
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index 8e613a5a1f..14092243f5 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -10,6 +10,17 @@ def _context_for_ohem():
     return FCNHead(in_channels=32, channels=16, num_classes=19)
 
 
+def _context_for_ohem_multiple_loss():
+    return FCNHead(
+        in_channels=32,
+        channels=16,
+        num_classes=19,
+        loss_decode=[
+            dict(type='CrossEntropyLoss', loss_name='loss_1'),
+            dict(type='CrossEntropyLoss', loss_name='loss_2')
+        ])
+
+
 def test_ohem_sampler():
 
     with pytest.raises(AssertionError):
@@ -37,3 +48,31 @@ def test_ohem_sampler():
     assert seg_weight.shape[0] == seg_logit.shape[0]
     assert seg_weight.shape[1:] == seg_logit.shape[2:]
     assert seg_weight.sum() == 200
+
+    # test multiple losses case
+    with pytest.raises(AssertionError):
+        # seg_logit and seg_label must be of the same size
+        sampler = OHEMPixelSampler(context=_context_for_ohem_multiple_loss())
+        seg_logit = torch.randn(1, 19, 45, 45)
+        seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
+        sampler.sample(seg_logit, seg_label)
+
+    # test with thresh in multiple losses case
+    sampler = OHEMPixelSampler(
+        context=_context_for_ohem_multiple_loss(), thresh=0.7, min_kept=200)
+    seg_logit = torch.randn(1, 19, 45, 45)
+    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
+    seg_weight = sampler.sample(seg_logit, seg_label)
+    assert seg_weight.shape[0] == seg_logit.shape[0]
+    assert seg_weight.shape[1:] == seg_logit.shape[2:]
+    assert seg_weight.sum() > 200
+
+    # test w.o thresh in multiple losses case
+    sampler = OHEMPixelSampler(
+        context=_context_for_ohem_multiple_loss(), min_kept=200)
+    seg_logit = torch.randn(1, 19, 45, 45)
+    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
+    seg_weight = sampler.sample(seg_logit, seg_label)
+    assert seg_weight.shape[0] == seg_logit.shape[0]
+    assert seg_weight.shape[1:] == seg_logit.shape[2:]
+    assert seg_weight.sum() == 200