This repository has been archived by the owner on Nov 3, 2022. It is now read-only.

Segmentation metrics and losses #197

Open · wants to merge 18 commits into base: master

Conversation


@ahundt ahundt commented Dec 19, 2017

Metrics and losses for semantic segmentation,
adapted from: https://github.com/theduynguyen/Keras-FCN

@ahundt ahundt requested review from tboquet and titu1994 and removed request for tboquet December 19, 2017 18:14

ahundt commented Dec 20, 2017

@titu1994 any chance you might be willing to review this?


fchouteau commented Jan 2, 2018

Hi,

None, None as H, W
The last commits seem to handle calculations with (None, None, None, n_classes) tensors, which seems ideal since, by definition, FCNs are not shape-dependent (except for the requirement that the input height/width be a multiple of the network's stride).

However, the pixelwise accuracy seems to require computing the product of int_shape, which fails for tensors with None shapes.
Could this be a viable solution?

import keras.backend as K

def pixel_accuracy(y_true, y_pred):
    pred_shape = K.int_shape(y_pred)
    true_shape = K.int_shape(y_true)

    # reshape such that the w and h dims are multiplied together
    y_pred_reshaped = K.reshape(y_pred, (-1, pred_shape[-1]))
    y_true_reshaped = K.reshape(y_true, (-1, true_shape[-1]))

    # correctly classified
    clf_pred = K.one_hot(K.argmax(y_pred_reshaped), num_classes=true_shape[-1])
    correct_pixels_per_class = K.cast(K.equal(clf_pred, y_true_reshaped), dtype='float32')

    return K.mean(correct_pixels_per_class, axis=0)

Tensor flattening during calculations
Also, I was wondering why the tensors are "flattened" during loss and metric calculation and reshaped afterwards (in the case of binary_crossentropy). Is it for performance reasons? Perhaps to handle any dimension? (IIRC K.binary_crossentropy can handle 4D tensors correctly.)

It should be possible to do the calculations directly on (batch_size, H, W, n_classes) tensors, with only n_classes needed as a prior for one-hot encoding (it is not required for losses, for example).

For example, with the Jaccard calculation:

def mean_intersection_over_union(y_true, y_pred):
    pred_shape = K.int_shape(y_pred)
    true_shape = K.int_shape(y_true)

    # reshape such that the w and h dims are multiplied together
    y_pred_reshaped = K.reshape(y_pred, (-1, pred_shape[-1]))
    y_true_reshaped = K.reshape(y_true, (-1, true_shape[-1]))

    # correctly classified
    clf_pred = K.one_hot(K.argmax(y_pred_reshaped), num_classes=true_shape[-1])
    equal_entries = K.cast(K.equal(clf_pred, y_true_reshaped), dtype='float32') * y_true_reshaped
    print(K.int_shape(clf_pred))
    intersection = K.sum(equal_entries, axis=1)
    union_per_class = K.sum(y_true_reshaped, axis=1) + K.sum(y_pred_reshaped, axis=1)
    print(K.int_shape(intersection))
    # epsilon added to avoid dividing by zero
    iou = intersection / ((union_per_class - intersection) + K.epsilon())

    return K.mean(iou)

def mean_intersection_over_union_nr(y_true, y_pred):
    pred_shape = K.int_shape(y_pred)
    true_shape = K.int_shape(y_true)

    # no reshape: keep the (batch_size, h, w, n_classes) layout
    y_pred_reshaped = y_pred
    y_true_reshaped = y_true

    # correctly classified
    clf_pred = K.one_hot(K.argmax(y_pred_reshaped), num_classes=true_shape[-1])
    equal_entries = K.cast(K.equal(clf_pred, y_true_reshaped), dtype='float32') * y_true_reshaped
    print(K.int_shape(clf_pred))
    print(K.int_shape(equal_entries))
    intersection = K.sum(equal_entries, axis=3)
    union_per_class = K.sum(y_true_reshaped, axis=3) + K.sum(y_pred_reshaped, axis=3)

    print(K.int_shape(intersection))
    # epsilon added to avoid dividing by zero
    iou = intersection / ((union_per_class - intersection) + K.epsilon())

    return K.mean(iou)

IoU when I == U == 0
Also, one last question about your IoU calculation: when I = 0 and U = 0 (for example, in binary crossentropy where the image is fully background, or in categorical crossentropy when you want to calculate the IoU only on classes > 1 to exclude background classes), would the IoU be 0 or 1?
Should you want IoU = 1 in such a case, it would be better to compute IoU with the following formula: iou = (intersection + K.epsilon()) / ((union_per_class - intersection) + K.epsilon())

Here is an example of my own regarding what I mean:

def jaccard_metric(class_zero_is_background=False):
    """
    Jaccard metric decorator
    Args:
        class_zero_is_background(bool): Calculate loss on class zero or not

    Returns:
        Jaccard loss function

        Formula is IOU:
        IOU = (output*target+smooth)/(output+target-output*target+smooth)
    """

    def jaccard(output, target):
        """

        Args:
            output(tensor): Tensor of shape (batch_size,w,h,num_classes). Output of SOFTMAX Activation
            target: Tensor of shape (batch_size,w,h,num_classes). one hot encoded class matrix

        Returns:
            jaccard estimation
        """
        smooth = 1.
        if class_zero_is_background:
            output = output[:, :, :, 1:]
            target = target[:, :, :, 1:]
        output = K.clip(K.abs(output), K.epsilon(), 1. - K.epsilon())
        target = K.clip(K.abs(target), K.epsilon(), 1. - K.epsilon())

        union = K.sum(output + target, axis=(1, 2, 3))
        intersection = K.sum(output * target, axis=(1, 2, 3))

        iou = (intersection + smooth) / (union - intersection + smooth)

        return iou

    return jaccard

PS: Should my tone be inappropriate, I apologize; my objective is just to understand the design principles behind your implementations, not to criticize them in any fashion.


ahundt commented Jan 4, 2018

@fchouteau I appreciate the review! I was hoping for one, which is why I didn't merge these changes yet.
Do you think you could put your changes in a pull request to my github.com/ahundt/keras-contrib segmentation_loss branch?

That way it will be easier for me to look at the difference, and merge them into this PR if your changes turn out to be the way to go. I'll also write up a response and make a second post shortly. Thanks!


ahundt commented Jan 4, 2018

Here are my first thoughts on your questions:

Tensor flattening during calculations
Also, I was wondering why the tensors are "flattened" during loss and metric calculation and reshaped afterwards (in the case of binary_crossentropy). Is it for performance reasons? Perhaps to handle any dimension? (IIRC K.binary_crossentropy can handle 4D tensors correctly.)

If binary_crossentropy() in Keras does handle pixel-wise losses correctly, that is a relatively new feature, and some of this may have been written before it worked correctly. IIRC, the loss may have been implemented with use cases like TimeDistributed in mind and may not have been correct for pixel-wise segmentation problems. Has that changed?

For pixel_accuracy(), I believe your suggested K.mean(correct_pixels_per_class, axis=0) would definitely be an improvement, so long as there isn't a shape issue.
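
For illustration, here is a minimal sketch (my own, not code from this PR) of a pixel accuracy that avoids static shapes entirely by comparing argmaxes over the class axis, assuming channels-last one-hot y_true:

import keras.backend as K

def pixel_accuracy_sketch(y_true, y_pred):
    # argmax over the class axis also works for (None, None, None, n_classes) tensors
    true_labels = K.argmax(y_true, axis=-1)
    pred_labels = K.argmax(y_pred, axis=-1)
    # fraction of pixels whose predicted class matches the ground truth
    return K.mean(K.cast(K.equal(true_labels, pred_labels), 'float32'))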


ahundt commented Jan 4, 2018

Regarding jaccard (aka IoU), I should probably drop that since jaccard.py has already been merged, unless there is some advantage to one of the other versions, such as (1) avoiding reshaping the output or (2) ignoring the background class (as in your version).

However, in your version I'd suggest changing:

def jaccard_metric(class_zero_is_background=False):

to

def jaccard_metric(classes_to_ignore=None):

This would accept a single class integer or a list of integers to ignore, since the background class can vary from one dataset to another, such as pascal_voc, where I believe 255 is the background class.
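
A minimal sketch of the argument handling I have in mind (the helper name is hypothetical, not code in this PR):

def _normalize_classes_to_ignore(classes_to_ignore):
    # accept None, a single class index, or any iterable of indices
    if classes_to_ignore is None:
        return []
    if isinstance(classes_to_ignore, int):
        return [classes_to_ignore]
    return list(classes_to_ignore)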


ahundt commented Jan 4, 2018

Regarding:

        iou = (intersection + smooth) / (union - intersection + smooth)

You're definitely right about adding epsilon or smooth in both the numerator and denominator. I was tweaking the loss to avoid divide-by-zero and missed that! Thanks!


ahundt commented Jan 5, 2018

I'm fairly certain the reshape is still necessary, as only one axis will be utilized in the loss, rather than the two needed for segmentation problems:
https://github.com/keras-team/keras/blob/master/keras/losses.py#L68

@fchouteau If you have another solution or can explain why it is not necessary, please advise.
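
As an illustration of the flattening pattern I mean (a rough sketch, not the exact PR code; it assumes channels last and a statically known class dimension):

import keras.backend as K
from keras import losses

def flattened_categorical_crossentropy(y_true, y_pred):
    # collapse batch, height and width into one axis so the existing Keras
    # loss, which reduces over the last axis only, sees one sample per pixel
    num_classes = K.int_shape(y_pred)[-1]
    y_true_flat = K.reshape(y_true, (-1, num_classes))
    y_pred_flat = K.reshape(y_pred, (-1, num_classes))
    return K.mean(losses.categorical_crossentropy(y_true_flat, y_pred_flat))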

… that lets existing loss functions be reused

fchouteau commented Jan 8, 2018

Regarding jaccard (aka IoU), I should probably drop that since jaccard.py has already been merged, unless there is some advantage to one of the other versions, such as (1) avoiding reshaping the output or (2) ignoring the background class (as in your version).
However, in your version I'd suggest changing:
def jaccard_metric(class_zero_is_background=False):
This would accept a single class integer or a list of integers to ignore, since the background class can vary from one dataset to another, such as pascal_voc, where I believe 255 is the background class.

Since K.gather does not support an axis argument, I am afraid I do not have any idea how to proceed for backends other than tf.

Should classes_to_ignore be a contiguous list, we could exclude those classes by calling:

if classes_to_ignore is not None:
    num_classes = K.int_shape(output)[3]
    classes_to_keep = [cls for cls in range(num_classes) if cls not in classes_to_ignore]
    output = tf.gather(output, classes_to_keep, axis=3)
    target = tf.gather(target, classes_to_keep, axis=3)

However, I do not have any idea how to do it in a Keras-y way.


ahundt commented Jan 9, 2018

However I do not have any idea how to do it in a Keras-y way

Add a K.gather to each keras_contrib backend; the other backends can simply raise NotImplementedError until a user on those platforms implements it.
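
A minimal sketch of what that could look like (the file layout and wrapper are assumptions for illustration, not the actual keras_contrib API):

# keras_contrib/backend/tensorflow_backend.py (hypothetical)
import tensorflow as tf

def gather(reference, indices, axis=0):
    # thin wrapper so loss/metric code can stay backend-agnostic
    return tf.gather(reference, indices, axis=axis)

# keras_contrib/backend/theano_backend.py (hypothetical)
def gather(reference, indices, axis=0):
    raise NotImplementedError('gather with an axis argument is not yet '
                              'implemented for this backend.')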

"""
from keras import losses
import keras.backend as K
import tensorflow as tf
Contributor


Is it necessary?

Collaborator Author


Nope, it should be easy to remove; I got it out of all the functions but forgot to remove the import, haha.

adapted from: https://github.com/theduynguyen/Keras-FCN
"""
import keras.backend as K
import tensorflow as tf
Contributor

@nzw0301 nzw0301 Jan 12, 2018


Same as above.

@fchouteau

I committed some proposals to your original branch.

A side note, however: the function mean_accuracy is not calculating the accuracy, because you are summing only over "positive pixels per class". By calculating TP / (TP + FN) you are effectively calculating the mean recall per class.


ahundt commented Jan 28, 2018

@fchouteau sorry for the delay, can I get a link to the location of what you committed?

A side note, however: the function mean_accuracy is not calculating the accuracy, because you are summing only over "positive pixels per class". By calculating TP / (TP + FN) you are effectively calculating the mean recall per class.

Whoops, very good catch! I should have noticed that myself...

Hmm, perhaps accuracy, precision, and recall should all be available.

note to incorporate later:
accuracy = (TP + TN)/(TP + TN + FP + FN)
recall = TP/(TP + FN)
precision = TP/(TP + FP)
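
A rough sketch of how those could be computed per class (my own illustration, not the PR code; it assumes channels-last one-hot y_true, softmax y_pred, and a statically known class dimension):

import keras.backend as K

def _confusion_counts(y_true, y_pred):
    num_classes = K.int_shape(y_pred)[-1]
    pred_onehot = K.one_hot(K.argmax(y_pred, axis=-1), num_classes)
    y_true = K.cast(y_true, 'float32')
    axes = [0, 1, 2]  # sum over batch, height and width -> per-class counts
    tp = K.sum(pred_onehot * y_true, axis=axes)
    fp = K.sum(pred_onehot * (1. - y_true), axis=axes)
    fn = K.sum((1. - pred_onehot) * y_true, axis=axes)
    tn = K.sum((1. - pred_onehot) * (1. - y_true), axis=axes)
    return tp, fp, fn, tn

def mean_recall(y_true, y_pred):
    tp, fp, fn, tn = _confusion_counts(y_true, y_pred)
    return K.mean(tp / (tp + fn + K.epsilon()))

def mean_precision(y_true, y_pred):
    tp, fp, fn, tn = _confusion_counts(y_true, y_pred)
    return K.mean(tp / (tp + fp + K.epsilon()))

def mean_class_accuracy(y_true, y_pred):
    tp, fp, fn, tn = _confusion_counts(y_true, y_pred)
    return K.mean((tp + tn) / (tp + tn + fp + fn + K.epsilon()))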

Changes in segmentation_losses for consistency

ahundt commented Jan 28, 2018

found it and merged in d49e794.

Also, it looks like Keras now has keras.backend.gather(reference, indices). That should work for all backends that support it!
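
Since K.gather only gathers along axis 0, here is an untested sketch (my own illustration) of how it could still be used to drop ignored classes: move the class axis first, gather, then move it back.

import keras.backend as K

def keep_classes(tensor, classes_to_keep):
    # (batch, h, w, classes) -> (classes, batch, h, w)
    class_first = K.permute_dimensions(tensor, (3, 0, 1, 2))
    kept = K.gather(class_first, classes_to_keep)
    # back to (batch, h, w, kept_classes)
    return K.permute_dimensions(kept, (1, 2, 3, 0))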

* master: (37 commits)
  Ignore the crf_test.py
  Moving scipy installations from conda to pip
  Adding nose dependency
  Adding globalpolling to the imports
  Fix more typos
  Fix typos on CRF docstrings
  densenet.py correct docstring location
  densenet.py small docstring correction
  densenet.py bugfixes + renames as per review in keras-team#214
  densenet.py minor fixes to DenseNetFCN calls, initial kernel 3x3
  densenet.py correct transition pooling defaults to correspond with their respective papers
  densenet.py include_top updated to require_flatten in upstream keras
  densenet.py max pool by default, transition_pooling option, add DenseNetFCN early_transition option.
  remove some space lines to follow pep8
  add logsumexp for all backend
  import logsumexp from keras.tensorflow_backend
  add logsumexp in tensorflow_backend.py
  README.md travis badge URL -> keras-team
  jaccard.py clarify documentation
  Fix DenseNet-BC model def
  ...

ahundt commented Feb 4, 2018

@fchouteau I went back to do an update, and now that I look in more detail, I think accuracy is mostly correct, except that it should be renamed to categorical_accuracy. See how categorical_accuracy is implemented in Keras itself:
https://github.com/keras-team/keras/blob/master/keras/metrics.py#L29


bhack commented Apr 8, 2018

@ahundt Some other interesting metrics in Python (especially F-boundary).

if smooth is None:
    smooth = K.epsilon()
intersection = K.sum(K.abs(y_true * y_pred), axis=axis)
sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=axis)
jac = (intersection + smooth) / (sum_ - intersection + smooth)
return (1 - jac) * smooth


I believe you need to remove the multiplication by smooth here, so that we return only 1 - jac.



def binary_jaccard_distance(y_true, y_pred, smooth=None, axis=-1):
    return jaccard_distance(K.round(y_pred), y_pred, smooth, axis)


Mistake here.
Please fix to:
return jaccard_distance(y_true, K.round(y_pred), smooth, axis)

Collaborator Author


Good catch, that's a pretty critical bug.
