Skip to content

Commit

Permalink
[Enhance] Support reading class_weight from file in loss functions to…
Browse files Browse the repository at this point in the history
… help MMDet3D (open-mmlab#513)

* support reading class_weight from file in loss function

* add unit test of loss with class_weight from file

* minor fix

* move get_class_weight to utils
  • Loading branch information
Wuziyi616 authored Apr 29, 2021
1 parent ce56e68 commit 771ca7d
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 12 deletions.
8 changes: 4 additions & 4 deletions mmseg/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss
from .utils import get_class_weight, weight_reduce_loss


def cross_entropy(pred,
Expand Down Expand Up @@ -146,8 +146,8 @@ class CrossEntropyLoss(nn.Module):
Defaults to False.
reduction (str, optional): . Defaults to 'mean'.
Options are "none", "mean" and "sum".
class_weight (list[float], optional): Weight of each class.
Defaults to None.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
"""

Expand All @@ -163,7 +163,7 @@ def __init__(self,
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
self.class_weight = get_class_weight(class_weight)

if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
Expand Down
8 changes: 4 additions & 4 deletions mmseg/models/losses/dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss
from .utils import get_class_weight, weighted_loss


@weighted_loss
Expand Down Expand Up @@ -63,8 +63,8 @@ class DiceLoss(nn.Module):
reduction (str, optional): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_image is True. Default: 'mean'.
class_weight (list[float], optional): The weight for each class.
Default: None.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Default to 1.0.
ignore_index (int | None): The label index to be ignored. Default: 255.
"""
Expand All @@ -81,7 +81,7 @@ def __init__(self,
self.smooth = smooth
self.exponent = exponent
self.reduction = reduction
self.class_weight = class_weight
self.class_weight = get_class_weight(class_weight)
self.loss_weight = loss_weight
self.ignore_index = ignore_index

Expand Down
8 changes: 4 additions & 4 deletions mmseg/models/losses/lovasz_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss
from .utils import get_class_weight, weight_reduce_loss


def lovasz_grad(gt_sorted):
Expand Down Expand Up @@ -240,8 +240,8 @@ class LovaszLoss(nn.Module):
reduction (str, optional): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_image is True. Default: 'mean'.
class_weight (list[float], optional): The weight for each class.
Default: None.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
"""

Expand Down Expand Up @@ -269,7 +269,7 @@ def __init__(self,
self.per_image = per_image
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
self.class_weight = get_class_weight(class_weight)

def forward(self,
cls_score,
Expand Down
20 changes: 20 additions & 0 deletions mmseg/models/losses/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,28 @@
import functools

import mmcv
import numpy as np
import torch.nn.functional as F


def get_class_weight(class_weight):
"""Get class weight for loss function.
Args:
class_weight (list[float] | str | None): If class_weight is a str,
take it as a file name and read from it.
"""
if isinstance(class_weight, str):
# take it as a file path
if class_weight.endswith('.npy'):
class_weight = np.load(class_weight)
else:
# pkl, json or yaml
class_weight = mmcv.load(class_weight)

return class_weight


def reduce_loss(loss, reduction):
"""Reduce loss as specified.
Expand Down
28 changes: 28 additions & 0 deletions tests/test_models/test_losses/test_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,34 @@ def test_ce_loss():
fake_label = torch.Tensor([1]).long()
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))

# test loss with class weights from file
import os
import tempfile
import mmcv
import numpy as np
tmp_file = tempfile.NamedTemporaryFile()

mmcv.dump([0.8, 0.2], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))

np.save(f'{tmp_file.name}.npy', np.array([0.8, 0.2])) # from npy file
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=f'{tmp_file.name}.npy',
loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
tmp_file.close()
os.remove(f'{tmp_file.name}.pkl')
os.remove(f'{tmp_file.name}.npy')

loss_cls_cfg = dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
Expand Down
30 changes: 30 additions & 0 deletions tests/test_models/test_losses/test_dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,36 @@ def test_dice_lose():
labels = (torch.rand(8, 4, 4) * 3).long()
dice_loss(logits, labels)

# test loss with class weights from file
import os
import tempfile
import mmcv
import numpy as np
tmp_file = tempfile.NamedTemporaryFile()

mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
loss_cfg = dict(
type='DiceLoss',
reduction='none',
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0,
ignore_index=1)
dice_loss = build_loss(loss_cfg)
dice_loss(logits, labels, ignore_index=None)

np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
loss_cfg = dict(
type='DiceLoss',
reduction='none',
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0,
ignore_index=1)
dice_loss = build_loss(loss_cfg)
dice_loss(logits, labels, ignore_index=None)
tmp_file.close()
os.remove(f'{tmp_file.name}.pkl')
os.remove(f'{tmp_file.name}.npy')

# test dice loss with loss_type = 'binary'
loss_cfg = dict(
type='DiceLoss',
Expand Down
30 changes: 30 additions & 0 deletions tests/test_models/test_losses/test_lovasz_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,36 @@ def test_lovasz_loss():
labels = (torch.rand(1, 4, 4) * 2).long()
lovasz_loss(logits, labels, ignore_index=None)

# test loss with class weights from file
import os
import tempfile
import mmcv
import numpy as np
tmp_file = tempfile.NamedTemporaryFile()

mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
loss_cfg = dict(
type='LovaszLoss',
per_image=True,
reduction='mean',
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0)
lovasz_loss = build_loss(loss_cfg)
lovasz_loss(logits, labels, ignore_index=None)

np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
loss_cfg = dict(
type='LovaszLoss',
per_image=True,
reduction='mean',
class_weight=f'{tmp_file.name}.npy',
loss_weight=1.0)
lovasz_loss = build_loss(loss_cfg)
lovasz_loss(logits, labels, ignore_index=None)
tmp_file.close()
os.remove(f'{tmp_file.name}.pkl')
os.remove(f'{tmp_file.name}.npy')

# test lovasz loss with loss_type = 'binary' and per_image = False
loss_cfg = dict(
type='LovaszLoss',
Expand Down

0 comments on commit 771ca7d

Please sign in to comment.