Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Support reading class_weight from file in loss functions to help MMDet3D #513

Merged
merged 4 commits into from
Apr 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mmseg/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss
from .utils import get_class_weight, weight_reduce_loss


def cross_entropy(pred,
Expand Down Expand Up @@ -146,8 +146,8 @@ class CrossEntropyLoss(nn.Module):
Defaults to False.
reduction (str, optional): . Defaults to 'mean'.
Options are "none", "mean" and "sum".
class_weight (list[float], optional): Weight of each class.
Defaults to None.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
"""

Expand All @@ -163,7 +163,7 @@ def __init__(self,
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
self.class_weight = get_class_weight(class_weight)

if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
Expand Down
8 changes: 4 additions & 4 deletions mmseg/models/losses/dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss
from .utils import get_class_weight, weighted_loss


@weighted_loss
Expand Down Expand Up @@ -63,8 +63,8 @@ class DiceLoss(nn.Module):
reduction (str, optional): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_image is True. Default: 'mean'.
class_weight (list[float], optional): The weight for each class.
Default: None.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Default to 1.0.
ignore_index (int | None): The label index to be ignored. Default: 255.
"""
Expand All @@ -81,7 +81,7 @@ def __init__(self,
self.smooth = smooth
self.exponent = exponent
self.reduction = reduction
self.class_weight = class_weight
self.class_weight = get_class_weight(class_weight)
self.loss_weight = loss_weight
self.ignore_index = ignore_index

Expand Down
8 changes: 4 additions & 4 deletions mmseg/models/losses/lovasz_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss
from .utils import get_class_weight, weight_reduce_loss


def lovasz_grad(gt_sorted):
Expand Down Expand Up @@ -240,8 +240,8 @@ class LovaszLoss(nn.Module):
reduction (str, optional): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_image is True. Default: 'mean'.
class_weight (list[float], optional): The weight for each class.
Default: None.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
"""

Expand Down Expand Up @@ -269,7 +269,7 @@ def __init__(self,
self.per_image = per_image
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
self.class_weight = get_class_weight(class_weight)

def forward(self,
cls_score,
Expand Down
20 changes: 20 additions & 0 deletions mmseg/models/losses/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,28 @@
import functools

import mmcv
import numpy as np
import torch.nn.functional as F


def get_class_weight(class_weight):
"""Get class weight for loss function.

Args:
class_weight (list[float] | str | None): If class_weight is a str,
take it as a file name and read from it.
"""
if isinstance(class_weight, str):
# take it as a file path
if class_weight.endswith('.npy'):
class_weight = np.load(class_weight)
else:
# pkl, json or yaml
class_weight = mmcv.load(class_weight)

return class_weight


def reduce_loss(loss, reduction):
"""Reduce loss as specified.

Expand Down
28 changes: 28 additions & 0 deletions tests/test_models/test_losses/test_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,34 @@ def test_ce_loss():
fake_label = torch.Tensor([1]).long()
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))

# test loss with class weights from file
import os
import tempfile
import mmcv
import numpy as np
tmp_file = tempfile.NamedTemporaryFile()

mmcv.dump([0.8, 0.2], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))

np.save(f'{tmp_file.name}.npy', np.array([0.8, 0.2])) # from npy file
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=f'{tmp_file.name}.npy',
loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
tmp_file.close()
os.remove(f'{tmp_file.name}.pkl')
os.remove(f'{tmp_file.name}.npy')

loss_cls_cfg = dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
Expand Down
30 changes: 30 additions & 0 deletions tests/test_models/test_losses/test_dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,36 @@ def test_dice_lose():
labels = (torch.rand(8, 4, 4) * 3).long()
dice_loss(logits, labels)

# test loss with class weights from file
import os
import tempfile
import mmcv
import numpy as np
tmp_file = tempfile.NamedTemporaryFile()

mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
loss_cfg = dict(
type='DiceLoss',
reduction='none',
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0,
ignore_index=1)
dice_loss = build_loss(loss_cfg)
dice_loss(logits, labels, ignore_index=None)

np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
loss_cfg = dict(
type='DiceLoss',
reduction='none',
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0,
ignore_index=1)
dice_loss = build_loss(loss_cfg)
dice_loss(logits, labels, ignore_index=None)
tmp_file.close()
os.remove(f'{tmp_file.name}.pkl')
os.remove(f'{tmp_file.name}.npy')

# test dice loss with loss_type = 'binary'
loss_cfg = dict(
type='DiceLoss',
Expand Down
30 changes: 30 additions & 0 deletions tests/test_models/test_losses/test_lovasz_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,36 @@ def test_lovasz_loss():
labels = (torch.rand(1, 4, 4) * 2).long()
lovasz_loss(logits, labels, ignore_index=None)

# test loss with class weights from file
import os
import tempfile
import mmcv
import numpy as np
tmp_file = tempfile.NamedTemporaryFile()

mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
loss_cfg = dict(
type='LovaszLoss',
per_image=True,
reduction='mean',
class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0)
lovasz_loss = build_loss(loss_cfg)
lovasz_loss(logits, labels, ignore_index=None)

np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
loss_cfg = dict(
type='LovaszLoss',
per_image=True,
reduction='mean',
class_weight=f'{tmp_file.name}.npy',
loss_weight=1.0)
lovasz_loss = build_loss(loss_cfg)
lovasz_loss(logits, labels, ignore_index=None)
tmp_file.close()
os.remove(f'{tmp_file.name}.pkl')
os.remove(f'{tmp_file.name}.npy')

# test lovasz loss with loss_type = 'binary' and per_image = False
loss_cfg = dict(
type='LovaszLoss',
Expand Down