Generalized dice loss for multi-class segmentation #9395

Closed
@emoebel

Hey guys, I just implemented the generalised Dice loss (a multi-class version of the Dice loss), as described in ref :
(my targets have shape (batch_size, image_dim1, image_dim2, image_dim3, nb_of_classes))

def generalized_dice_loss_w(y_true, y_pred): 
    # Compute weights: "the contribution of each label is corrected by the inverse of its volume"
    Ncl = y_pred.shape[-1]
    w = np.zeros((Ncl,))
    for l in range(0,Ncl): w[l] = np.sum( np.asarray(y_true[:,:,:,:,l]==1,np.int8) )
    w = 1/(w**2+0.00001)

    # Compute gen dice coef:
    numerator = y_true*y_pred
    numerator = w*K.sum(numerator,(0,1,2,3))
    numerator = K.sum(numerator)
    
    denominator = y_true+y_pred
    denominator = w*K.sum(denominator,(0,1,2,3))
    denominator = K.sum(denominator)
    
    gen_dice_coef = numerator/denominator
    
    return 1-2*gen_dice_coef
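
For checking the arithmetic on concrete arrays, here is a minimal NumPy-only sketch of the same formula as the function above. It is a numerical reference under stated assumptions, not a Keras loss: the name generalized_dice_loss_np and the eps parameter are hypothetical, and the inputs are assumed to be plain arrays with one-hot targets. Inside a Keras loss, y_true and y_pred are usually symbolic tensors, so the np.sum and np.asarray calls in the weight computation may not behave as they do on arrays. Computing the weights with backend ops such as K.sum is probably the safer route.

```python
import numpy as np

def generalized_dice_loss_np(y_true, y_pred, eps=1e-5):
    """NumPy reference; inputs have shape (batch, d1, d2, d3, n_classes)."""
    reduce_axes = (0, 1, 2, 3)  # sum over batch and the three spatial axes

    # Per-class reference volume, then inverse-squared-volume weights.
    volume = y_true.sum(axis=reduce_axes)
    w = 1.0 / (volume ** 2 + eps)

    numerator = np.sum(w * np.sum(y_true * y_pred, axis=reduce_axes))
    denominator = np.sum(w * np.sum(y_true + y_pred, axis=reduce_axes))
    return 1.0 - 2.0 * numerator / denominator

# Toy data (assumption): random labels for 4 classes on a 2 x 4 x 4 x 4 grid.
rng = np.random.default_rng(0)
labels = rng.integers(0, 4, size=(2, 4, 4, 4))
y_true = np.eye(4)[labels]  # one-hot targets, shape (2, 4, 4, 4, 4)

# A perfect prediction should give a loss of (numerically) zero.
loss_perfect = generalized_dice_loss_np(y_true, y_true)

# A background-only prediction should give a loss strictly between 0 and 1.
y_background = np.zeros_like(y_true)
y_background[..., 0] = 1.0
loss_background = generalized_dice_loss_np(y_true, y_background)
```

If the Keras version returns different values on the same arrays, the weight computation is the first place to look.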

But something must be wrong. I'm working with 3D images that I have to segment into 4 classes (1 background class and 3 object classes; the dataset is imbalanced). First odd thing: while my training loss and accuracy improve during training (and converge really fast), my validation loss and accuracy stay constant through the epochs (see image). Second, when predicting on test data, only the background class is predicted: I get a constant volume.

I used the exact same data and script but with categorical cross-entropy loss and get plausible results (object classes are segmented). Which means something is wrong with my implementation. Any idea what it could be?

Plus, I believe it would be useful to the Keras community to have a generalised Dice loss implementation, as it seems to be used in most recent semantic segmentation work (at least in the medical imaging community).

PS: the way the weights are defined seems odd to me; I get values around 10^-10. Has anyone else tried to implement this? I also tested my function without the weights but got the same problems.
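
On the size of the weights: with w = 1/volume^2, a class of about 10^5 voxels gives w of about 10^-10, so values of that magnitude follow from the definition. Because w appears in both the numerator and the denominator, its absolute scale cancels and only the ratios between classes matter. The sketch below illustrates this with made-up numbers; the volumes, intersection, and total arrays are hypothetical, and the small epsilon from the function above is left out.

```python
import numpy as np

# Hypothetical per-class voxel counts: one large background class, three small ones.
volumes = np.array([1e6, 1e5, 5e4, 2e4])
w = 1.0 / volumes ** 2  # tiny values, as expected for large volumes

# Hypothetical per-class sums of y_true * y_pred and of y_true + y_pred.
intersection = np.array([9e5, 4e4, 1e4, 5e3])
total = np.array([2e6, 1.8e5, 9e4, 3e4])

def gdl(weights):
    """Generalised Dice loss from precomputed per-class sums."""
    return 1.0 - 2.0 * np.sum(weights * intersection) / np.sum(weights * total)

loss_raw = gdl(w)                   # weights as defined, around 1e-12 to 1e-9
loss_rescaled = gdl(w / w.sum())    # same weights rescaled to sum to 1
```

Both calls return the same loss up to floating-point error, so the small weight values alone are unlikely to explain the constant validation loss.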
