Skip to content

Commit

Permalink
[Fix] Add avg_non_ignore in cross entropy loss (open-mmlab#1409)
Browse files Browse the repository at this point in the history
* [Fix] Add avg_non_ignore in cross entropy loss

* [Fix] Add avg_non_ignore in cross entropy loss

* add docstring

* fix ut

* fix docstring and comments

* fix

* fix bce

* fix avg_factor in BCE and add more ut

* add avg_non_ignore

* add more ut

* fix part of ut

* fix part of ut

* test avg_non_ignore would not affect ce/bce when reduction none/sum

* test avg_non_ignore would not affect ce/bce when reduction none/sum/mean

* re-organize ut

* re-organize ut

* re-organize ut

* re-organize hardcode case

* fix parts of comments

* fix another parts of comments

* fix
  • Loading branch information
MengzhangLI authored Mar 30, 2022
1 parent d1b8eae commit cc89c8d
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 36 deletions.
20 changes: 20 additions & 0 deletions docs/en/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,23 @@ model = dict(
In this way, `loss_weight` and `loss_name` will be weight and name in training log of corresponding loss, respectively.

Note: If you want this loss item to be included into the backward graph, `loss_` must be the prefix of the name.

## Ignore specified label index in loss calculation

In default setting, `avg_non_ignore=False` which means each pixel counts for loss calculation although some of them belong to ignore-index labels.

For loss calculation, we support ignore index of certain label by `avg_non_ignore` and `ignore_index`. In this way, the average loss would only be calculated in non-ignored labels which may achieve better performance, and here is the [reference](https://github.com/open-mmlab/mmsegmentation/pull/1409). Here is an example config of training `unet` on `Cityscapes` dataset: in loss calculation it would ignore label 0 which is background and loss average is only calculated on non-ignore labels:

```python
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
model = dict(
decode_head=dict(
ignore_index=0,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
auxiliary_head=dict(
ignore_index=0,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
))
```
25 changes: 25 additions & 0 deletions docs/zh_cn/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,28 @@ model = dict(
通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`

注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。

## 在损失函数中忽略特定的 label 类别

默认设置 `avg_non_ignore=False`, 即每个像素都用来计算损失函数。尽管其中的一些像素属于需要被忽略的类别。

对于训练时损失函数的计算,我们目前支持使用 `avg_non_ignore``ignore_index` 来忽略 label 特定的类别。 这样损失函数将只在非忽略类别像素中求平均值,会获得更好的表现。这里是[相关 PR](https://github.com/open-mmlab/mmsegmentation/pull/1409)。以 `unet` 使用 `Cityscapes` 数据集训练为例,
在计算损失函数时,忽略 label 为0的背景,并且仅在不被忽略的像素上计算均值。配置文件写为:

```python
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
model = dict(
decode_head=dict(
ignore_index=0,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
auxiliary_head=dict(
ignore_index=0,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
))
```

通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`

注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。
92 changes: 80 additions & 12 deletions mmseg/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -13,8 +15,31 @@ def cross_entropy(pred,
class_weight=None,
reduction='mean',
avg_factor=None,
ignore_index=-100):
"""The wrapper function for :func:`F.cross_entropy`"""
ignore_index=-100,
avg_non_ignore=False):
"""cross_entropy. The wrapper function for :func:`F.cross_entropy`
Args:
pred (torch.Tensor): The prediction with shape (N, 1).
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
Default: None.
class_weight (list[float], optional): The weight for each class.
Default: None.
reduction (str, optional): The method used to reduce the loss.
Options are 'none', 'mean' and 'sum'. Default: 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Default: None.
ignore_index (int): Specifies a target value that is ignored and
does not contribute to the input gradients. When
``avg_non_ignore `` is ``True``, and the ``reduction`` is
``''mean''``, the loss is averaged over non-ignored targets.
Defaults: -100.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
`New in version 0.23.0.`
"""

# class_weight is a manual rescaling weight given to each class.
# If given, has to be a Tensor of size C element-wise losses
loss = F.cross_entropy(
Expand All @@ -25,6 +50,11 @@ def cross_entropy(pred,
ignore_index=ignore_index)

# apply weights and do the reduction
# average loss over non-ignored elements
# pytorch's official cross_entropy average loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = label.numel() - (label == ignore_index).sum().item()
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
Expand All @@ -46,13 +76,14 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
bin_labels[inds[0], labels[valid_mask]] = 1

valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()

if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
bin_label_weights *= valid_mask

return bin_labels, bin_label_weights
return bin_labels, bin_label_weights, valid_mask


def binary_cross_entropy(pred,
Expand All @@ -61,19 +92,25 @@ def binary_cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=255):
ignore_index=-100,
avg_non_ignore=False,
**kwargs):
"""Calculate the binary CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 1).
label (torch.Tensor): The learning label of the prediction.
Note: In bce loss, label < 0 is invalid.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored. Default: 255
ignore_index (int): The label index to be ignored. Default: -100.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
`New in version 0.23.0.`
Returns:
torch.Tensor: The calculated loss
Expand All @@ -83,12 +120,21 @@ def binary_cross_entropy(pred,
pred.dim() == 4 and label.dim() == 3), \
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
'H, W], label shape [N, H, W] are supported'
label, weight = _expand_onehot_labels(label, weight, pred.shape,
ignore_index)
# `weight` returned from `_expand_onehot_labels`
# has been treated for valid (non-ignore) pixels
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.shape, ignore_index)
else:
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
weight *= valid_mask
else:
weight = valid_mask
# average loss over non-ignored and valid elements
if reduction == 'mean' and avg_factor is None and avg_non_ignore:
avg_factor = valid_mask.sum().item()

# weighted element-wise losses
if weight is not None:
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
Expand All @@ -104,7 +150,8 @@ def mask_cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=None):
ignore_index=None,
**kwargs):
"""Calculate the CrossEntropy loss for masks.
Args:
Expand Down Expand Up @@ -153,6 +200,9 @@ class CrossEntropyLoss(nn.Module):
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
`New in version 0.23.0.`
"""

def __init__(self,
Expand All @@ -161,14 +211,22 @@ def __init__(self,
reduction='mean',
class_weight=None,
loss_weight=1.0,
loss_name='loss_ce'):
loss_name='loss_ce',
avg_non_ignore=False):
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight)
self.avg_non_ignore = avg_non_ignore
if not self.avg_non_ignore and self.reduction == 'mean':
warnings.warn(
'Default ``avg_non_ignore`` is False, if you would like to '
'ignore the certain label and average loss over non-ignore '
'labels, which is the same with PyTorch official '
'cross_entropy, set ``avg_non_ignore=True``.')

if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
Expand All @@ -178,12 +236,18 @@ def __init__(self,
self.cls_criterion = cross_entropy
self._loss_name = loss_name

def extra_repr(self):
"""Extra repr."""
s = f'avg_non_ignore={self.avg_non_ignore}'
return s

def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
ignore_index=-100,
**kwargs):
"""Forward function."""
assert reduction_override in (None, 'none', 'mean', 'sum')
Expand All @@ -193,13 +257,16 @@ def forward(self,
class_weight = cls_score.new_tensor(self.class_weight)
else:
class_weight = None
# Note: for BCE loss, label < 0 is invalid.
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
avg_non_ignore=self.avg_non_ignore,
ignore_index=ignore_index,
**kwargs)
return loss_cls

Expand All @@ -212,6 +279,7 @@ def loss_name(self):
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
Expand Down
6 changes: 5 additions & 1 deletion mmseg/models/losses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import mmcv
import numpy as np
import torch
import torch.nn.functional as F


Expand Down Expand Up @@ -69,7 +70,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
loss = loss.sum() / avg_factor
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = torch.finfo(torch.float32).eps
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
Expand Down
Loading

0 comments on commit cc89c8d

Please sign in to comment.