@@ -15,7 +15,7 @@ def dice_loss(pred,
15
15
smooth = 1 ,
16
16
exponent = 2 ,
17
17
class_weight = None ,
18
- ignore_index = - 1 ):
18
+ ignore_index = 255 ):
19
19
assert pred .shape [0 ] == target .shape [0 ]
20
20
total_loss = 0
21
21
num_classes = pred .shape [1 ]
@@ -36,9 +36,9 @@ def dice_loss(pred,
36
36
@weighted_loss
37
37
def binary_dice_loss (pred , target , valid_mask , smooth = 1 , exponent = 2 , ** kwards ):
38
38
assert pred .shape [0 ] == target .shape [0 ]
39
- pred = pred .contiguous (). view (pred .shape [0 ], - 1 )
40
- target = target .contiguous (). view (target .shape [0 ], - 1 )
41
- valid_mask = valid_mask .contiguous (). view (valid_mask .shape [0 ], - 1 )
39
+ pred = pred .reshape (pred .shape [0 ], - 1 )
40
+ target = target .reshape (target .shape [0 ], - 1 )
41
+ valid_mask = valid_mask .reshape (valid_mask .shape [0 ], - 1 )
42
42
43
43
num = torch .sum (torch .mul (pred , target ) * valid_mask , dim = 1 ) * 2 + smooth
44
44
den = torch .sum (pred .pow (exponent ) + target .pow (exponent ), dim = 1 ) + smooth
@@ -70,27 +70,27 @@ class DiceLoss(nn.Module):
70
70
"""
71
71
72
72
def __init__ (self ,
73
- loss_type = 'multi_class' ,
74
73
smooth = 1 ,
75
74
exponent = 2 ,
76
75
reduction = 'mean' ,
77
76
class_weight = None ,
78
77
loss_weight = 1.0 ,
79
- ignore_index = 255 ):
78
+ ignore_index = 255 ,
79
+ ** kwards ):
80
80
super (DiceLoss , self ).__init__ ()
81
- assert loss_type in ['multi_class' , 'binary' ]
82
- if loss_type == 'multi_class' :
83
- self .cls_criterion = dice_loss
84
- else :
85
- self .cls_criterion = binary_dice_loss
86
81
self .smooth = smooth
87
82
self .exponent = exponent
88
83
self .reduction = reduction
89
84
self .class_weight = class_weight
90
85
self .loss_weight = loss_weight
91
86
self .ignore_index = ignore_index
92
87
93
- def forward (self , pred , target , avg_factor = None , reduction_override = None ):
88
+ def forward (self ,
89
+ pred ,
90
+ target ,
91
+ avg_factor = None ,
92
+ reduction_override = None ,
93
+ ** kwards ):
94
94
assert reduction_override in (None , 'none' , 'mean' , 'sum' )
95
95
reduction = (
96
96
reduction_override if reduction_override else self .reduction )
@@ -100,10 +100,13 @@ def forward(self, pred, target, avg_factor=None, reduction_override=None):
100
100
class_weight = None
101
101
102
102
pred = F .softmax (pred , dim = 1 )
103
- one_hot_target = F .one_hot (torch .clamp_min (target .long (), 0 ))
103
+ num_classes = pred .shape [1 ]
104
+ one_hot_target = F .one_hot (
105
+ torch .clamp (target .long (), 0 , num_classes - 1 ),
106
+ num_classes = num_classes )
104
107
valid_mask = (target != self .ignore_index ).long ()
105
108
106
- loss = self .loss_weight * self . cls_criterion (
109
+ loss = self .loss_weight * dice_loss (
107
110
pred ,
108
111
one_hot_target ,
109
112
valid_mask = valid_mask ,
0 commit comments