diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py index 44798421aa..42c0790c98 100644 --- a/mmseg/models/losses/cross_entropy_loss.py +++ b/mmseg/models/losses/cross_entropy_loss.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from ..builder import LOSSES -from .utils import weight_reduce_loss +from .utils import get_class_weight, weight_reduce_loss def cross_entropy(pred, @@ -146,8 +146,8 @@ class CrossEntropyLoss(nn.Module): Defaults to False. reduction (str, optional): . Defaults to 'mean'. Options are "none", "mean" and "sum". - class_weight (list[float], optional): Weight of each class. - Defaults to None. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. loss_weight (float, optional): Weight of the loss. Defaults to 1.0. """ @@ -163,7 +163,7 @@ def __init__(self, self.use_mask = use_mask self.reduction = reduction self.loss_weight = loss_weight - self.class_weight = class_weight + self.class_weight = get_class_weight(class_weight) if self.use_sigmoid: self.cls_criterion = binary_cross_entropy diff --git a/mmseg/models/losses/dice_loss.py b/mmseg/models/losses/dice_loss.py index b94ece3a28..27a77b962d 100644 --- a/mmseg/models/losses/dice_loss.py +++ b/mmseg/models/losses/dice_loss.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from ..builder import LOSSES -from .utils import weighted_loss +from .utils import get_class_weight, weighted_loss @weighted_loss @@ -63,8 +63,8 @@ class DiceLoss(nn.Module): reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when per_image is True. Default: 'mean'. - class_weight (list[float], optional): The weight for each class. - Default: None. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. loss_weight (float, optional): Weight of the loss. Default to 1.0. ignore_index (int | None): The label index to be ignored. Default: 255. """ @@ -81,7 +81,7 @@ def __init__(self, self.smooth = smooth self.exponent = exponent self.reduction = reduction - self.class_weight = class_weight + self.class_weight = get_class_weight(class_weight) self.loss_weight = loss_weight self.ignore_index = ignore_index diff --git a/mmseg/models/losses/lovasz_loss.py b/mmseg/models/losses/lovasz_loss.py index 859a656b9f..e8df6e8307 100644 --- a/mmseg/models/losses/lovasz_loss.py +++ b/mmseg/models/losses/lovasz_loss.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from ..builder import LOSSES -from .utils import weight_reduce_loss +from .utils import get_class_weight, weight_reduce_loss def lovasz_grad(gt_sorted): @@ -240,8 +240,8 @@ class LovaszLoss(nn.Module): reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when per_image is True. Default: 'mean'. - class_weight (list[float], optional): The weight for each class. - Default: None. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. loss_weight (float, optional): Weight of the loss. Defaults to 1.0. """ @@ -269,7 +269,7 @@ def __init__(self, self.per_image = per_image self.reduction = reduction self.loss_weight = loss_weight - self.class_weight = class_weight + self.class_weight = get_class_weight(class_weight) def forward(self, cls_score, diff --git a/mmseg/models/losses/utils.py b/mmseg/models/losses/utils.py index a1153fa9f3..ab5876603e 100644 --- a/mmseg/models/losses/utils.py +++ b/mmseg/models/losses/utils.py @@ -1,8 +1,28 @@ import functools +import mmcv +import numpy as np import torch.nn.functional as F +def get_class_weight(class_weight): + """Get class weight for loss function. + + Args: + class_weight (list[float] | str | None): If class_weight is a str, + take it as a file name and read from it. + """ + if isinstance(class_weight, str): + # take it as a file path + if class_weight.endswith('.npy'): + class_weight = np.load(class_weight) + else: + # pkl, json or yaml + class_weight = mmcv.load(class_weight) + + return class_weight + + def reduce_loss(loss, reduction): """Reduce loss as specified. diff --git a/tests/test_models/test_losses/test_ce_loss.py b/tests/test_models/test_losses/test_ce_loss.py index 35ef84348d..9619b60a91 100644 --- a/tests/test_models/test_losses/test_ce_loss.py +++ b/tests/test_models/test_losses/test_ce_loss.py @@ -25,6 +25,34 @@ def test_ce_loss(): fake_label = torch.Tensor([1]).long() assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.)) + # test loss with class weights from file + import os + import tempfile + import mmcv + import numpy as np + tmp_file = tempfile.NamedTemporaryFile() + + mmcv.dump([0.8, 0.2], f'{tmp_file.name}.pkl', 'pkl') # from pkl file + loss_cls_cfg = dict( + type='CrossEntropyLoss', + use_sigmoid=False, + class_weight=f'{tmp_file.name}.pkl', + loss_weight=1.0) + loss_cls = build_loss(loss_cls_cfg) + assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.)) + + np.save(f'{tmp_file.name}.npy', np.array([0.8, 0.2])) # from npy file + loss_cls_cfg = dict( + type='CrossEntropyLoss', + use_sigmoid=False, + class_weight=f'{tmp_file.name}.npy', + loss_weight=1.0) + loss_cls = build_loss(loss_cls_cfg) + assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.)) + tmp_file.close() + os.remove(f'{tmp_file.name}.pkl') + os.remove(f'{tmp_file.name}.npy') + loss_cls_cfg = dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) loss_cls = build_loss(loss_cls_cfg) diff --git a/tests/test_models/test_losses/test_dice_loss.py b/tests/test_models/test_losses/test_dice_loss.py index 94b9faab71..01ded6fe74 100644 --- a/tests/test_models/test_losses/test_dice_loss.py +++ b/tests/test_models/test_losses/test_dice_loss.py @@ -16,6 +16,36 @@ def test_dice_lose(): labels = (torch.rand(8, 4, 4) * 3).long() dice_loss(logits, labels) + # test loss with class weights from file + import os + import tempfile + import mmcv + import numpy as np + tmp_file = tempfile.NamedTemporaryFile() + + mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file + loss_cfg = dict( + type='DiceLoss', + reduction='none', + class_weight=f'{tmp_file.name}.pkl', + loss_weight=1.0, + ignore_index=1) + dice_loss = build_loss(loss_cfg) + dice_loss(logits, labels, ignore_index=None) + + np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file + loss_cfg = dict( + type='DiceLoss', + reduction='none', + class_weight=f'{tmp_file.name}.pkl', + loss_weight=1.0, + ignore_index=1) + dice_loss = build_loss(loss_cfg) + dice_loss(logits, labels, ignore_index=None) + tmp_file.close() + os.remove(f'{tmp_file.name}.pkl') + os.remove(f'{tmp_file.name}.npy') + # test dice loss with loss_type = 'binary' loss_cfg = dict( type='DiceLoss', diff --git a/tests/test_models/test_losses/test_lovasz_loss.py b/tests/test_models/test_losses/test_lovasz_loss.py index e11dd613fa..6fac4309a9 100644 --- a/tests/test_models/test_losses/test_lovasz_loss.py +++ b/tests/test_models/test_losses/test_lovasz_loss.py @@ -38,6 +38,36 @@ def test_lovasz_loss(): labels = (torch.rand(1, 4, 4) * 2).long() lovasz_loss(logits, labels, ignore_index=None) + # test loss with class weights from file + import os + import tempfile + import mmcv + import numpy as np + tmp_file = tempfile.NamedTemporaryFile() + + mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file + loss_cfg = dict( + type='LovaszLoss', + per_image=True, + reduction='mean', + class_weight=f'{tmp_file.name}.pkl', + loss_weight=1.0) + lovasz_loss = build_loss(loss_cfg) + lovasz_loss(logits, labels, ignore_index=None) + + np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file + loss_cfg = dict( + type='LovaszLoss', + per_image=True, + reduction='mean', + class_weight=f'{tmp_file.name}.npy', + loss_weight=1.0) + lovasz_loss = build_loss(loss_cfg) + lovasz_loss(logits, labels, ignore_index=None) + tmp_file.close() + os.remove(f'{tmp_file.name}.pkl') + os.remove(f'{tmp_file.name}.npy') + # test lovasz loss with loss_type = 'binary' and per_image = False loss_cfg = dict( type='LovaszLoss',