From c4cbe988d69b9c3b1cf1af8c2c883084ff3ccade Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Mon, 12 Oct 2020 23:56:49 +0200 Subject: [PATCH 1/6] feat: Added support of Complement Cross-Entropy --- holocron/nn/functional.py | 57 ++++++++++++++++++++++++++++++++++++- holocron/nn/modules/loss.py | 25 +++++++++++++++- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/holocron/nn/functional.py b/holocron/nn/functional.py index da8bb8cb1..b243e181f 100644 --- a/holocron/nn/functional.py +++ b/holocron/nn/functional.py @@ -10,7 +10,7 @@ __all__ = ['silu', 'mish', 'hard_mish', 'nl_relu', 'focal_loss', 'multilabel_cross_entropy', 'ls_cross_entropy', - 'norm_conv2d', 'add2d', 'dropblock2d'] + 'complement_cross_entropy', 'norm_conv2d', 'add2d', 'dropblock2d'] def silu(x): @@ -237,6 +237,61 @@ def ls_cross_entropy(x, target, weight=None, ignore_index=-100, reduction='mean' ignore_index=ignore_index, reduction=reduction) +def complement_cross_entropy(x, target, weight=None, ignore_index=-100, reduction='mean', gamma=-1): + """Implements the complement cross entropy loss from + `"Imbalanced Image Classification with Complement Cross Entropy" <https://arxiv.org/abs/2009.02189>`_ + + Args: + x (torch.Tensor[N, K, ...]): input tensor + target (torch.Tensor[N, ...]): target tensor + weight (torch.Tensor[K], optional): manual rescaling of each class + ignore_index (int, optional): specifies target value that is ignored and does not contribute to gradient + reduction (str, optional): reduction method + gamma (float, optional): complement factor + + Returns: + torch.Tensor: loss reduced with `reduction` method + """ + + if gamma == 0: + return F.cross_entropy(x, target, weight, ignore_index=ignore_index, reduction=reduction) + + # log(P[class]) = log_softmax(score)[class] + # logpt = F.log_softmax(x, dim=1) + + pt = F.softmax(x, dim=1) + pt /= (1 - pt.transpose(0, 1).gather(0, target.unsqueeze(0)).transpose(0, 1)) + + loss = - 1 / (x.shape[1] - 1) * pt * torch.log(pt) + + # Nullify contributions to the 
loss + # TODO: vectorize or write CUDA extension + for class_idx in torch.unique(target): + loss[:, class_idx][target == class_idx] = 0. + + # Ignore index (set loss contribution to 0) + if ignore_index >= 0 and ignore_index < x.shape[1]: + loss[:, ignore_index] = 0. + + # Weight + if weight is not None: + # Tensor type + if weight.type() != x.data.type(): + weight = weight.type_as(x.data) + loss *= weight.view(1, -1, *([1] * (x.ndim - 2))) + + # Loss reduction + if reduction == 'sum': + loss = loss.sum() + else: + loss = loss.sum(dim=1) + if reduction == 'mean': + loss = loss.mean() + + # Combine cross-entropy with the gamma-weighted complement entropy + return F.cross_entropy(x, target, weight, ignore_index=ignore_index, reduction=reduction) + gamma * loss + + def _xcorrNd(fn, x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, normalize_slices=False, eps=1e-14): """Implements cross-correlation operation""" diff --git a/holocron/nn/modules/loss.py b/holocron/nn/modules/loss.py index f8aa5a264..868213159 100644 --- a/holocron/nn/modules/loss.py +++ b/holocron/nn/modules/loss.py @@ -8,7 +8,8 @@ import torch.nn as nn from .. 
import functional as F -__all__ = ['FocalLoss', 'MultiLabelCrossEntropy', 'LabelSmoothingCrossEntropy', 'MixupLoss', 'ClassBalancedWrapper'] +__all__ = ['FocalLoss', 'MultiLabelCrossEntropy', 'LabelSmoothingCrossEntropy', 'ComplementCrossEntropy', + 'MixupLoss', 'ClassBalancedWrapper'] class _Loss(nn.Module): @@ -107,6 +108,28 @@ def __repr__(self): return f"{self.__class__.__name__}(eps={self.eps}, reduction='{self.reduction}')" +class ComplementCrossEntropy(_Loss): + """Implements the complement cross entropy loss from + `"Imbalanced Image Classification with Complement Cross Entropy" <https://arxiv.org/abs/2009.02189>`_ + + Args: + gamma (float, optional): complement factor + weight (torch.Tensor[K], optional): class weight for loss computation + ignore_index (int, optional): specifies target value that is ignored and does not contribute to gradient + reduction (str, optional): type of reduction to apply to the final loss + """ + + def __init__(self, gamma=-1, **kwargs): + super().__init__(**kwargs) + self.gamma = gamma + + def forward(self, x, target): + return F.complement_cross_entropy(x, target, self.weight, self.ignore_index, self.reduction, self.gamma) + + def __repr__(self): + return f"{self.__class__.__name__}(gamma={self.gamma}, reduction='{self.reduction}')" + + class MixupLoss(_Loss): """Implements a Mixup wrapper as described in `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/pdf/1710.09412.pdf>`_ From e0000d21bd502b716fd4d8065b03e94ba78a4f17 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Mon, 12 Oct 2020 23:57:06 +0200 Subject: [PATCH 2/6] test: Added unittests for CCE Loss --- test/test_nn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_nn.py b/test/test_nn.py index 0f4396003..6978c6d6e 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -393,7 +393,8 @@ def do_test(self, mod_name=mod_name): loss_modules = [('FocalLoss', 'focal_loss'), - ('LabelSmoothingCrossEntropy', 'ls_cross_entropy')] + ('LabelSmoothingCrossEntropy', 'ls_cross_entropy'), + 
('ComplementCrossEntropy', 'complement_cross_entropy')] for (mod_name, fn_name) in loss_modules: def do_test(self, mod_name=mod_name, fn_name=fn_name): From 262159c6ffd9a3715caa3e5bc6f0e0d293028635 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Mon, 12 Oct 2020 23:57:45 +0200 Subject: [PATCH 3/6] docs: Updated documentation --- docs/source/nn.functional.rst | 2 ++ docs/source/nn.rst | 2 ++ 2 files changed, 4 insertions(+) diff --git a/docs/source/nn.functional.rst b/docs/source/nn.functional.rst index cc0d61e27..a8944ddf3 100644 --- a/docs/source/nn.functional.rst +++ b/docs/source/nn.functional.rst @@ -30,6 +30,8 @@ Loss functions .. autofunction:: ls_cross_entropy +.. autofunction:: complement_cross_entropy + Convolutions ------------ diff --git a/docs/source/nn.rst b/docs/source/nn.rst index d18e1be01..700b0b8fc 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -28,6 +28,8 @@ Loss functions .. autoclass:: LabelSmoothingCrossEntropy +.. autoclass:: ComplementCrossEntropy + Loss wrappers -------------- From ce82f49a0f117bc83f8ed2d58e66537e75551920 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Mon, 12 Oct 2020 23:58:45 +0200 Subject: [PATCH 4/6] docs: Updated readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ef5630ab5..8a7834da5 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ conda install -c frgfm pylocron ##### Main features - Activation: [SiLU/Swish](https://arxiv.org/abs/1606.08415), [Mish](https://arxiv.org/abs/1908.08681), [HardMish](https://github.com/digantamisra98/H-Mish), [NLReLU](https://arxiv.org/abs/1908.03682), [FReLU](https://arxiv.org/abs/2007.11824) -- Loss: [Focal Loss](https://arxiv.org/abs/1708.02002), MultiLabelCrossEntropy, [LabelSmoothingCrossEntropy](https://arxiv.org/pdf/1706.03762.pdf), [MixupLoss](https://arxiv.org/pdf/1710.09412.pdf), [ClassBalancedWrapper](https://arxiv.org/abs/1901.05555) +- Loss: [Focal Loss](https://arxiv.org/abs/1708.02002), 
MultiLabelCrossEntropy, [LabelSmoothingCrossEntropy](https://arxiv.org/pdf/1706.03762.pdf), [MixupLoss](https://arxiv.org/pdf/1710.09412.pdf), [ClassBalancedWrapper](https://arxiv.org/abs/1901.05555), [ComplementCrossEntropy](https://arxiv.org/abs/2009.02189) - Convolutions: [NormConv2d](https://arxiv.org/pdf/2005.05274v2.pdf), [Add2d](https://arxiv.org/pdf/1912.13200.pdf), [SlimConv2d](https://arxiv.org/pdf/2003.07469.pdf), [PyConv2d](https://arxiv.org/abs/2006.11538) - Pooling: [BlurPool2d](https://arxiv.org/abs/1904.11486) - Regularization: [DropBlock](https://arxiv.org/abs/1810.12890) From d47d563118c7f3ee23a0642197ed274a24b043ca Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Tue, 13 Oct 2020 10:17:46 +0200 Subject: [PATCH 5/6] fix: Fixed multi label cross entropy class weights --- holocron/nn/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/holocron/nn/functional.py b/holocron/nn/functional.py index b243e181f..f3cda91b4 100644 --- a/holocron/nn/functional.py +++ b/holocron/nn/functional.py @@ -174,7 +174,7 @@ def multilabel_cross_entropy(x, target, weight=None, ignore_index=-100, reductio # Tensor type if weight.type() != x.data.type(): weight = weight.type_as(x.data) - logpt *= weight.view(1, -1) + logpt *= weight.view(1, -1, *([1] * (x.ndim - 2))) # CE Loss loss = - target * logpt From 162bfe9b13fe04c4ea08dbea9c83d14b217862a6 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Tue, 13 Oct 2020 10:18:00 +0200 Subject: [PATCH 6/6] test: Expanded loss module unittests --- test/test_nn.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 6978c6d6e..e19066395 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -163,12 +163,18 @@ def _test_loss_module(self, name, fn_name, multi_label=False): target = torch.rand(x.shape) else: target = (num_classes * torch.rand(num_batches, 20, 20)).to(torch.long) - criterion = loss.__dict__[name]() - 
self.assertEqual(criterion(x, target).item(), - F.__dict__[fn_name](x, target).item()) - criterion = loss.__dict__[name](reduction='none') - self.assertTrue(torch.equal(criterion(x, target), - F.__dict__[fn_name](x, target, reduction='none'))) + + # Check type casting of weights + class_weights = torch.ones(num_classes, dtype=torch.float16) + ignore_index = 0 + + # Check values between function and module + for reduction in ['none', 'sum', 'mean']: + # Check type casting of weights + criterion = loss.__dict__[name](weight=class_weights, reduction=reduction, ignore_index=ignore_index) + self.assertTrue(torch.equal(criterion(x, target), + F.__dict__[fn_name](x, target, weight=class_weights, + reduction=reduction, ignore_index=ignore_index))) def test_concatdownsample2d(self):