feat: Added support of Complement CE loss #90

Merged 6 commits on Oct 13, 2020
README.md (1 addition, 1 deletion)
@@ -57,7 +57,7 @@ conda install -c frgfm pylocron
##### Main features

- Activation: [SiLU/Swish](https://arxiv.org/abs/1606.08415), [Mish](https://arxiv.org/abs/1908.08681), [HardMish](https://github.com/digantamisra98/H-Mish), [NLReLU](https://arxiv.org/abs/1908.03682), [FReLU](https://arxiv.org/abs/2007.11824)
- - Loss: [Focal Loss](https://arxiv.org/abs/1708.02002), MultiLabelCrossEntropy, [LabelSmoothingCrossEntropy](https://arxiv.org/pdf/1706.03762.pdf), [MixupLoss](https://arxiv.org/pdf/1710.09412.pdf), [ClassBalancedWrapper](https://arxiv.org/abs/1901.05555)
+ - Loss: [Focal Loss](https://arxiv.org/abs/1708.02002), MultiLabelCrossEntropy, [LabelSmoothingCrossEntropy](https://arxiv.org/pdf/1706.03762.pdf), [MixupLoss](https://arxiv.org/pdf/1710.09412.pdf), [ClassBalancedWrapper](https://arxiv.org/abs/1901.05555), [ComplementCrossEntropy](https://arxiv.org/abs/2009.02189)
- Convolutions: [NormConv2d](https://arxiv.org/pdf/2005.05274v2.pdf), [Add2d](https://arxiv.org/pdf/1912.13200.pdf), [SlimConv2d](https://arxiv.org/pdf/2003.07469.pdf), [PyConv2d](https://arxiv.org/abs/2006.11538)
- Pooling: [BlurPool2d](https://arxiv.org/abs/1904.11486)
- Regularization: [DropBlock](https://arxiv.org/abs/1810.12890)
docs/source/nn.functional.rst (2 additions)
@@ -30,6 +30,8 @@ Loss functions

.. autofunction:: ls_cross_entropy

.. autofunction:: complement_cross_entropy

Convolutions
------------

docs/source/nn.rst (2 additions)
@@ -28,6 +28,8 @@ Loss functions

.. autoclass:: LabelSmoothingCrossEntropy

.. autoclass:: ComplementCrossEntropy

Loss wrappers
--------------

holocron/nn/functional.py (57 additions, 2 deletions)
@@ -10,7 +10,7 @@


__all__ = ['silu', 'mish', 'hard_mish', 'nl_relu', 'focal_loss', 'multilabel_cross_entropy', 'ls_cross_entropy',
-           'norm_conv2d', 'add2d', 'dropblock2d']
+           'complement_cross_entropy', 'norm_conv2d', 'add2d', 'dropblock2d']


def silu(x):
@@ -174,7 +174,7 @@ def multilabel_cross_entropy
        # Tensor type
        if weight.type() != x.data.type():
            weight = weight.type_as(x.data)
-        logpt *= weight.view(1, -1)
+        logpt *= weight.view(1, -1, *([1] * (x.ndim - 2)))

    # CE Loss
    loss = - target * logpt
@@ -237,6 +237,61 @@ def ls_cross_entropy(x, target, weight=None, ignore_index=-100, reduction='mean'
                           ignore_index=ignore_index, reduction=reduction)


def complement_cross_entropy(x, target, weight=None, ignore_index=-100, reduction='mean', gamma=-1):
    """Implements the complement cross entropy loss from
    `"Imbalanced Image Classification with Complement Cross Entropy" <https://arxiv.org/pdf/2009.02189.pdf>`_

    Args:
        x (torch.Tensor[N, K, ...]): input tensor
        target (torch.Tensor[N, ...]): target tensor
        weight (torch.Tensor[K], optional): manual rescaling of each class
        ignore_index (int, optional): specifies a target value that is ignored and does not contribute to the gradient
        reduction (str, optional): reduction method
        gamma (float, optional): complement factor (a value of 0 disables the complement term)

    Returns:
        torch.Tensor: loss reduced with `reduction` method
    """

    if gamma == 0:
        return F.cross_entropy(x, target, weight, ignore_index=ignore_index, reduction=reduction)

    # Complement probabilities: p_j / (1 - p_g), where g is the ground-truth class
    # (out-of-place division so that autograd can still backpropagate through the softmax)
    pt = F.softmax(x, dim=1)
    pt = pt / (1 - pt.transpose(0, 1).gather(0, target.unsqueeze(0)).transpose(0, 1))

    # Complement entropy, averaged over the K - 1 complement classes
    loss = - 1 / (x.shape[1] - 1) * pt * torch.log(pt)

    # Nullify the contribution of the ground-truth class to the loss
    # TODO: vectorize or write CUDA extension
    for class_idx in torch.unique(target):
        loss[:, class_idx][target == class_idx] = 0.

    # Ignore index (set loss contribution to 0)
    if ignore_index >= 0 and ignore_index < x.shape[1]:
        loss[:, ignore_index] = 0.

    # Weight
    if weight is not None:
        # Tensor type
        if weight.type() != x.data.type():
            weight = weight.type_as(x.data)
        loss *= weight.view(1, -1, *([1] * (x.ndim - 2)))

    # Loss reduction
    if reduction == 'sum':
        loss = loss.sum()
    else:
        loss = loss.sum(dim=1)
        if reduction == 'mean':
            loss = loss.mean()

    # Standard cross entropy plus the gamma-weighted complement entropy term
    return F.cross_entropy(x, target, weight, ignore_index=ignore_index, reduction=reduction) + gamma * loss


def _xcorrNd(fn, x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1,
             normalize_slices=False, eps=1e-14):
    """Implements cross-correlation operation"""
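For reference, the new complement_cross_entropy above computes the paper's complement entropy on top of the standard cross entropy: with p = softmax(x) and ground-truth class g, the extra term is -1/(K-1) * sum over j != g of p~_j * log(p~_j), where p~_j = p_j / (1 - p_g), and the function returns CE + gamma * that term. The snippet below is a minimal usage sketch, not part of this PR; the import alias is hypothetical and shapes follow the docstring.

import torch
import torch.nn.functional as tF

from holocron.nn import functional as HF

# Segmentation-style inputs: 2 samples, 4 classes, 8x8 spatial map
logits = torch.rand(2, 4, 8, 8)
target = torch.randint(0, 4, (2, 8, 8))

# Default gamma=-1, as in the signature above
print(HF.complement_cross_entropy(logits, target).item())

# gamma=0 short-circuits to plain cross entropy
assert torch.allclose(HF.complement_cross_entropy(logits, target, gamma=0),
                      tF.cross_entropy(logits, target))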
holocron/nn/modules/loss.py (24 additions, 1 deletion)
@@ -8,7 +8,8 @@
import torch.nn as nn
from .. import functional as F

-__all__ = ['FocalLoss', 'MultiLabelCrossEntropy', 'LabelSmoothingCrossEntropy', 'MixupLoss', 'ClassBalancedWrapper']
+__all__ = ['FocalLoss', 'MultiLabelCrossEntropy', 'LabelSmoothingCrossEntropy', 'ComplementCrossEntropy',
+           'MixupLoss', 'ClassBalancedWrapper']


class _Loss(nn.Module):
@@ -107,6 +108,28 @@ def __repr__(self):
        return f"{self.__class__.__name__}(eps={self.eps}, reduction='{self.reduction}')"


class ComplementCrossEntropy(_Loss):
    """Implements the complement cross entropy loss from
    `"Imbalanced Image Classification with Complement Cross Entropy" <https://arxiv.org/pdf/2009.02189.pdf>`_

    Args:
        gamma (float, optional): complement factor
        weight (torch.Tensor[K], optional): class weight for loss computation
        ignore_index (int, optional): specifies a target value that is ignored and does not contribute to the gradient
        reduction (str, optional): type of reduction to apply to the final loss
    """

    def __init__(self, gamma=-1, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma

    def forward(self, x, target):
        return F.complement_cross_entropy(x, target, self.weight, self.ignore_index, self.reduction, self.gamma)

    def __repr__(self):
        return f"{self.__class__.__name__}(gamma={self.gamma}, reduction='{self.reduction}')"


class MixupLoss(_Loss):
    """Implements a Mixup wrapper as described in
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/pdf/1710.09412.pdf>`_
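As a quick sanity check of the module interface, a sketch (not part of the PR) could look like the following; it assumes holocron.nn re-exports its loss classes, which the nn.rst autoclass entry above suggests.

import torch
from holocron.nn import ComplementCrossEntropy

# 10-way classification logits for a batch of 8 samples
criterion = ComplementCrossEntropy(gamma=-1)
logits = torch.rand(8, 10, requires_grad=True)
target = torch.randint(0, 10, (8,))

loss = criterion(logits, target)
loss.backward()  # gradients flow through both the CE and the complement term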
test/test_nn.py (14 additions, 7 deletions)
@@ -163,12 +163,18 @@ def _test_loss_module(self, name, fn_name, multi_label=False):
            target = torch.rand(x.shape)
        else:
            target = (num_classes * torch.rand(num_batches, 20, 20)).to(torch.long)
-        criterion = loss.__dict__[name]()
-        self.assertEqual(criterion(x, target).item(),
-                         F.__dict__[fn_name](x, target).item())
+        criterion = loss.__dict__[name](reduction='none')
+        self.assertTrue(torch.equal(criterion(x, target),
+                                    F.__dict__[fn_name](x, target, reduction='none')))

        # Check type casting of weights
        class_weights = torch.ones(num_classes, dtype=torch.float16)
        ignore_index = 0

        # Check values between function and module
        for reduction in ['none', 'sum', 'mean']:
            # Check type casting of weights
            criterion = loss.__dict__[name](weight=class_weights, reduction=reduction, ignore_index=ignore_index)
            self.assertTrue(torch.equal(criterion(x, target),
                                        F.__dict__[fn_name](x, target, weight=class_weights,
                                                            reduction=reduction, ignore_index=ignore_index)))

    def test_concatdownsample2d(self):

@@ -393,7 +399,8 @@ def do_test(self, mod_name=mod_name):


loss_modules = [('FocalLoss', 'focal_loss'),
-                ('LabelSmoothingCrossEntropy', 'ls_cross_entropy')]
+                ('LabelSmoothingCrossEntropy', 'ls_cross_entropy'),
+                ('ComplementCrossEntropy', 'complement_cross_entropy')]

for (mod_name, fn_name) in loss_modules:
    def do_test(self, mod_name=mod_name, fn_name=fn_name):
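Outside the unittest harness, the parity the updated test asserts can be reproduced in a few lines. This is a sketch mirroring the test above (same shapes and arguments; the import alias is hypothetical):

import torch
from holocron.nn import ComplementCrossEntropy
from holocron.nn import functional as HF

x = torch.rand(2, 4, 20, 20)
target = torch.randint(0, 4, (2, 20, 20))
class_weights = torch.ones(4, dtype=torch.float16)

# The module and functional forms should agree for every reduction mode
for reduction in ['none', 'sum', 'mean']:
    criterion = ComplementCrossEntropy(weight=class_weights, reduction=reduction, ignore_index=0)
    assert torch.equal(criterion(x, target),
                       HF.complement_cross_entropy(x, target, weight=class_weights,
                                                   reduction=reduction, ignore_index=0))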