Surface loss in keras-tensorflow #14
Hey, this should actually be quite easy. The most difficult part is generating the distance maps from the ground truth, and fortunately this is done in NumPy. I don't know much about Keras or TensorFlow, but I assume you can either do that asynchronously in some data loader, or do it offline once. Assuming you have a class-encoded ground truth as an array of shape (w, h):
import numpy as np

def numpy_class2one_hot(seg: np.ndarray, C: int) -> np.ndarray:
    # Every label must be one of 0..C-1
    assert set(np.unique(seg)).issubset(list(range(C)))
    w, h = seg.shape  # type: Tuple[int, int]
    # One binary channel per class, stacked on a new leading axis
    res = np.stack([seg == c for c in range(C)], axis=0).astype(np.int32)
    assert res.shape == (C, w, h)
    assert np.all(res.sum(axis=0) == 1)  # exactly one class per pixel
    return res
This will output a one-hot encoded image of shape (C, w, h). Then, you could plug the result directly into the one_hot2dist function:
from scipy.ndimage import distance_transform_edt as distance

def one_hot2dist(seg: np.ndarray) -> np.ndarray:
    # seg must be one-hot along axis 0 (NumPy check in place of the torch-based one_hot helper)
    assert np.all(seg.sum(axis=0) == 1)
    C: int = len(seg)
    res = np.zeros_like(seg, dtype=np.float32)
    for c in range(C):
        posmask = seg[c].astype(bool)  # np.bool was removed in NumPy 1.24
        if posmask.any():
            negmask = ~posmask
            # Positive outside class c, zero on its inner border, negative deeper inside
            res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    return res
This will compute one distance map per class, for an output shape of (C, w, h). Then, you can load that in your favorite framework and multiply it element-wise with your network's softmax probabilities (which should have the same shape, (C, w, h)).
Hope that helps, let me know if you need any help. |
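A minimal NumPy/SciPy sketch of that last step, assuming channels-first arrays of shape (C, H, W) and reusing scipy.ndimage.distance_transform_edt as in the snippet above. The helper names `signed_dist_maps` and `surface_loss_np` are hypothetical, and a shifted copy of the ground truth stands in for an imperfect network output.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt as distance


def signed_dist_maps(one_hot: np.ndarray) -> np.ndarray:
    """Per-class signed distance maps for a one-hot mask of shape (C, H, W)."""
    res = np.zeros(one_hot.shape, dtype=np.float64)
    for c in range(one_hot.shape[0]):
        posmask = one_hot[c].astype(bool)
        if posmask.any():
            negmask = ~posmask
            # Positive outside class c, zero on its inner border, negative deeper inside
            res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    return res


def surface_loss_np(probs: np.ndarray, dist_maps: np.ndarray) -> float:
    """Mean of the element-wise product of softmax probabilities and distance maps."""
    assert probs.shape == dist_maps.shape
    return float((probs * dist_maps).mean())


# Toy two-class example: a 10x10 square inside a 32x32 image
seg = np.zeros((32, 32), dtype=np.int64)
seg[10:20, 10:20] = 1
gt = np.stack([seg == c for c in range(2)]).astype(np.float64)  # (2, 32, 32)
dist = signed_dist_maps(gt)

perfect = gt                            # prediction identical to the ground truth
shifted = np.roll(gt, shift=4, axis=2)  # same square moved 4 pixels to the right

loss_perfect = surface_loss_np(perfect, dist)
loss_shifted = surface_loss_np(shifted, dist)
print(loss_perfect, loss_shifted)
```

The perfect prediction only picks up non-positive distances, so its loss is negative, while the shifted one pays for every pixel it places outside the true object.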
@HKervadec First of all, a very nice paper! Do you have any test data and expected results for the loss function? I wanted to try your loss on my data (with a different project template), but I am not sure whether it's actually outputting the correct loss. Do you have any recommendations on how to test this? Thank you, |
Could you provide an implementation of surface_loss in Keras? |
Hey, sorry for the slow reply. I do not have much test data or reference values, but here is a testing script I made a while back that you could use as a basis:
#!/usr/bin/env python3.6
import unittest

import torch
import numpy as np

import utils


class TestDistMap(unittest.TestCase):
    def test_closure(self):
        a = np.zeros((1, 256, 256))
        a[:, 50:60, :] = 1
        o = utils.class2one_hot(torch.Tensor(a).type(torch.float32), C=2).numpy()
        res = utils.one_hot2dist(o[0])
        self.assertEqual(res.shape, (2, 256, 256))
        # The ground truth itself should only pick up the non-positive distances
        neg = (res <= 0) * res
        self.assertEqual(neg.sum(), (o * res).sum())

    def test_full_coverage(self):
        a = np.zeros((1, 256, 256))
        a[:, 50:60, :] = 1
        o = utils.class2one_hot(torch.Tensor(a).type(torch.float32), C=2).numpy()
        res = utils.one_hot2dist(o[0])
        self.assertEqual(res.shape, (2, 256, 256))
        self.assertEqual((res[1] <= 0).sum(), a.sum())
        self.assertEqual((res[1] > 0).sum(), (1 - a).sum())

    def test_empty(self):
        a = np.zeros((1, 256, 256))
        o = utils.class2one_hot(torch.Tensor(a).type(torch.float32), C=2).numpy()
        res = utils.one_hot2dist(o[0])
        self.assertEqual(res.shape, (2, 256, 256))
        self.assertEqual(res[1].sum(), 0)
        self.assertEqual((res[0] <= 0).sum(), a.size)

    def test_max_dist(self):
        """
        The max dist for a box should be at the middle of the object, +-1
        """
        a = np.zeros((1, 256, 256))
        a[:, 1:254, 1:254] = 1
        o = utils.class2one_hot(torch.Tensor(a).type(torch.float32), C=2).numpy()
        res = utils.one_hot2dist(o[0])
        self.assertEqual(res.shape, (2, 256, 256))
        self.assertEqual(res[0].max(), 127)
        self.assertEqual(np.unravel_index(res[0].argmax(), (256, 256)), (127, 127))
        self.assertEqual(res[1].min(), -126)
        self.assertEqual(np.unravel_index(res[1].argmin(), (256, 256)), (127, 127))

    def test_border(self):
        """
        Make sure the border inside the object is 0 in the distance map
        """
        for l in range(3, 5):
            a = np.zeros((1, 25, 25))
            a[:, 3:3 + l, 3:3 + l] = 1
            o = utils.class2one_hot(torch.Tensor(a).type(torch.float32), C=2).numpy()
            res = utils.one_hot2dist(o[0])
            self.assertEqual(res.shape, (2, 25, 25))
            border = (res[1] == 0)
            self.assertEqual(border.sum(), 4 * (l - 1))


if __name__ == "__main__":
    unittest.main()
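If you want to check the distance maps without torch, utils or SciPy, a brute-force Euclidean distance transform is easy to write for tiny images. This is a sketch for testing only (quadratic cost, and it assumes the mask has at least one False pixel); `brute_force_edt` and `brute_force_signed_dist` are hypothetical names. It reproduces the property checked in test_border.

```python
import numpy as np


def brute_force_edt(mask: np.ndarray) -> np.ndarray:
    """Distance from every True pixel to the nearest False pixel (0 on False pixels).

    O(N^2) reference, only meant for tiny test images; assumes at least one False pixel.
    """
    zeros = np.argwhere(~mask)
    out = np.zeros(mask.shape, dtype=np.float64)
    for idx in np.argwhere(mask):
        out[tuple(idx)] = np.sqrt(((zeros - idx) ** 2).sum(axis=1)).min()
    return out


def brute_force_signed_dist(posmask: np.ndarray) -> np.ndarray:
    """Same formula as one_hot2dist, for a single class, with the brute-force transform."""
    negmask = ~posmask
    return brute_force_edt(negmask) * negmask - (brute_force_edt(posmask) - 1) * posmask


# Mirror of test_border: the inner border of an l x l square should be exactly 0
border_counts = {}
for l in range(3, 5):
    a = np.zeros((25, 25), dtype=bool)
    a[3:3 + l, 3:3 + l] = True
    res = brute_force_signed_dist(a)
    border_counts[l] = int((res == 0).sum())
print(border_counts)  # {3: 8, 4: 12}, i.e. 4 * (l - 1)
```

On small masks you can then compare this reference against scipy.ndimage.distance_transform_edt element by element.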
I'm sorry, I have neither the time to do that nor the means to test it afterwards. But I think I gave enough information in #14 (comment) to do it. I would be willing to publish in this repository a Keras implementation that someone submits as a PR. |
@Eason270 @SouthAmericaB |
Hi,
from keras import backend as K
import numpy as np
import tensorflow as tf
from scipy.ndimage import distance_transform_edt as distance

def calc_dist_map(seg):
    res = np.zeros_like(seg)
    posmask = seg.astype(bool)  # np.bool was removed in NumPy 1.24

    if posmask.any():
        negmask = ~posmask
        # Signed distance: positive outside the foreground, <= 0 inside it
        res = distance(negmask) * negmask - (distance(posmask) - 1) * posmask

    return res

def calc_dist_map_batch(y_true):
    # Runs eagerly inside tf.py_function, so .numpy() is available
    y_true_numpy = y_true.numpy()
    return np.array([calc_dist_map(y)
                     for y in y_true_numpy]).astype(np.float32)

def surface_loss(y_true, y_pred):
    # Distance maps are computed on the fly, outside the TensorFlow graph
    y_true_dist_map = tf.py_function(func=calc_dist_map_batch,
                                     inp=[y_true],
                                     Tout=tf.float32)
    multiplied = y_pred * y_true_dist_map
    return K.mean(multiplied)
However, this implementation only covers the binary case (foreground and background). |
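A possible multi-class extension of calc_dist_map_batch, sketched in NumPy/SciPy only. It assumes channels-last one-hot targets of shape (B, H, W, C), as Keras usually uses, and computes one signed map per channel; presumably it could be called on y_true.numpy() from inside tf.py_function exactly like the binary version. The function names are hypothetical.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt as distance


def calc_dist_map_multiclass(seg: np.ndarray) -> np.ndarray:
    """Signed distance maps for one channels-last one-hot mask of shape (H, W, C)."""
    res = np.zeros(seg.shape, dtype=np.float32)
    for c in range(seg.shape[-1]):
        posmask = seg[..., c].astype(bool)
        if posmask.any():  # absent classes keep an all-zero map
            negmask = ~posmask
            res[..., c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    return res


def calc_dist_map_batch_multiclass(y_true: np.ndarray) -> np.ndarray:
    """Apply calc_dist_map_multiclass to every sample of a (B, H, W, C) batch."""
    return np.stack([calc_dist_map_multiclass(y) for y in y_true]).astype(np.float32)


# Toy batch: 2 samples, 3 classes; the second sample has no class-2 pixels
labels = np.zeros((2, 16, 16), dtype=np.int64)
labels[0, 2:6, 2:6] = 1
labels[0, 10:14, 10:14] = 2
labels[1, 4:12, 4:12] = 1
y_true = np.eye(3, dtype=np.float32)[labels]  # one-hot, shape (2, 16, 16, 3)

dist = calc_dist_map_batch_multiclass(y_true)
print(dist.shape, dist.dtype)
```

The loss itself stays the same element-wise product and mean; only the distance-map computation changes.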
This looks great. However (I am not very familiar with Keras and TensorFlow), the distance map isn't precomputed, meaning it has to be computed when evaluating the loss, which I assume slows things down. Is there some scheduling/compiling process that optimizes the order of the operations? From what I recall, TensorFlow builds its graph and then runs it in a session; a clever compiler would probably reorder those parts to avoid waiting for it. |
Hi! Thank you very much for your code! I use your surface_loss function in model.compile(..., loss=[surface_loss]). Do you have any idea what could be happening? Thank you very much in advance! |
@HKervadec, we considered computing distance maps before network training, but we noticed that it might be hard to associate a label (image mask) with the proper distance map during training (the dataset is shuffled), so we compute them on the fly.
import numpy as np
from keras import backend as K
from keras.callbacks import Callback

class AlphaScheduler(Callback):
    def __init__(self, alpha, update_fn):
        super().__init__()
        self.alpha = alpha  # a K.variable shared with the loss function
        self.update_fn = update_fn

    def on_epoch_end(self, epoch, logs=None):
        updated_alpha = self.update_fn(K.get_value(self.alpha))
        # Write the new value back into the shared variable
        K.set_value(self.alpha, updated_alpha)

alpha = K.variable(1, dtype='float32')

def update_alpha(value):
    # Decrease alpha by 0.01 per epoch, never going below 0.01
    return np.clip(value - 0.01, 0.01, 1)

history = model.fit_generator(
    ...,
    callbacks=[AlphaScheduler(alpha, update_alpha)]
) |
I am not quite sure I understand. How come it is harder to keep track of a tuple
In any case, thanks for all the contributions! The scheduler bit is very useful too; I will integrate it into the pull request as well. |
@HKervadec, thank you for sharing your improvement idea. We would have to take a closer look at whether such an implementation would be possible in our case. |
Thank you very much for your explanation!
Oh! That's true! Thank you very much! You will be doing me a great service! |
You are not limited to GDL for the regional loss; any other one can work (cross-entropy and its variants, Dice loss and its variants). The best one will depend on your specific application, but you can already try others. |
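The thread never shows the generalized_dice_loss used as the regional term, so here is a NumPy sketch of one common formulation, the generalized Dice loss of Sudre et al. (2017) with inverse squared class-volume weights. Channels-first (C, H, W) arrays, the function name and the eps value are assumptions; an absent class gets a very large weight under this formulation, which is one reason implementations differ in how they guard it.

```python
import numpy as np


def generalized_dice_loss_np(probs: np.ndarray, one_hot: np.ndarray, eps: float = 1e-6) -> float:
    """Generalized Dice loss with inverse squared class-volume weights.

    probs, one_hot: arrays of shape (C, H, W).
    """
    axes = (1, 2)
    w = 1.0 / (one_hot.sum(axis=axes) ** 2 + eps)  # one weight per class
    intersection = (w * (probs * one_hot).sum(axis=axes)).sum()
    union = (w * (probs + one_hot).sum(axis=axes)).sum()
    return float(1.0 - 2.0 * intersection / (union + eps))


seg = np.zeros((8, 8), dtype=np.int64)
seg[2:6, 2:6] = 1
one_hot = np.stack([seg == c for c in range(2)]).astype(np.float64)
uniform = np.full_like(one_hot, 0.5)  # an uninformative prediction

print(generalized_dice_loss_np(one_hot, one_hot))  # close to 0
print(generalized_dice_loss_np(uniform, one_hot))  # noticeably larger
```

Swapping in cross-entropy or a plain Dice loss only changes this regional term; the alpha-weighted combination with the surface loss stays the same.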
There is no problem in keeping track of a tuple |
Ah yes, I see; the scaling will be the main problem. In PyTorch these steps are usually done in the dataloader, before the final transforms, so this is not really a concern. But I understand that it would require changing the structure of the code a lot if this is not the case. |
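Keeping a precomputed distance map aligned with its mask under shuffling mostly comes down to indexing all arrays with one shared permutation. A minimal NumPy sketch, assuming images, masks and distance maps are stored as arrays with the same first axis (the generator name is hypothetical, and how the map is then handed to a Keras loss, e.g. concatenated to y_true, is left open):

```python
import numpy as np


def batch_generator(images, masks, dist_maps, batch_size, seed=0):
    """Yield aligned (images, masks, dist_maps) batches, reshuffled every epoch.

    A single permutation indexes all three arrays, so each mask always
    travels with its own precomputed distance map.
    """
    assert len(images) == len(masks) == len(dist_maps)
    rng = np.random.default_rng(seed)
    n = len(images)
    while True:  # Keras-style generators loop forever
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield images[idx], masks[idx], dist_maps[idx]


# Toy data whose values make misalignment easy to spot
images = np.arange(10, dtype=np.float64)
masks = images * 10
dist_maps = images * 100

gen = batch_generator(images, masks, dist_maps, batch_size=4)
batches = [next(gen) for _ in range(6)]  # two epochs of 3 batches (4, 4, 2)
```

Because the shuffle happens on indices rather than on each array separately, precomputing the maps once offline stays compatible with a shuffled dataset.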
Thank you very much! Is def update_alpha(value): correct? Or would it be
according to your code? |
@SouthAmericaB |
Hi,
I also tried reducing the learning rate every 10 or 15 epochs. (The pictures are just an example; I'm not sure they come from the best-performing weights.) |
Another question: is it possible for this loss to be ~23? Maybe my implementation is not correct? |
Hey,
It's definitely possible for big images. This is even truer in your case, where most of the generated distance map will be positive. You could try lowering the weight of the loss (i.e., dividing it by some factor), or normalizing it, as some other people reported in #18 (comment). I will close this issue, as the Keras/TensorFlow implementation itself seems solved. You can open another issue for your problem. Let me know |
Hi @marcinkaczor, I am trying to follow up on what you did and build a dynamic alpha scheduler balancing two separate losses. I now have the alpha value updating correctly, but I do not see the updated value propagating into the combined loss function itself. I assume the loss function cannot be written as a regular Python function, because then only the initial state is compiled into the model and any subsequent alpha changes are not picked up? How did you solve this issue? Thanks! /cc @akamojo |
Hi @tipani86, if you define
def gl_sl_wrapper(alpha):
    # alpha should be a K.variable, so that values set by the callback reach the compiled loss
    def gl_sl(y_true, y_pred):
        return alpha * generalized_dice_loss(y_true, y_pred) + (1 - alpha) * surface_loss(y_true, y_pred)

    return gl_sl

model.compile(loss=gl_sl_wrapper(alpha), optimizer="...", metrics=[...]) # alpha is the same variable we pass to AlphaScheduler (we combine |
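A pure-Python analogy (not Keras code) of why the factory needs a variable rather than a float: the float is bound once when gl_sl_wrapper is called, whereas a mutable object is updated in place and the closure sees the change. In graph-mode Keras there is the additional effect that a Python float is baked into the graph as a constant at compile time, while a K.variable is read at every step and changed with K.set_value. `ScalarVariable` and `loss_factory` are hypothetical stand-ins.

```python
class ScalarVariable:
    """Hypothetical stand-in for K.variable: a mutable holder updated in place."""

    def __init__(self, value: float):
        self.value = value

    def set_value(self, value: float) -> None:
        self.value = value


def loss_factory(alpha):
    def combined(region: float, surface: float) -> float:
        # Read alpha at call time: a holder exposes its current value, a float never changes
        a = alpha.value if isinstance(alpha, ScalarVariable) else alpha
        return a * region + (1 - a) * surface
    return combined


alpha_float = 1.0
alpha_var = ScalarVariable(1.0)
loss_with_float = loss_factory(alpha_float)
loss_with_var = loss_factory(alpha_var)

# "End of epoch": the scheduler updates alpha
alpha_float = 0.5         # rebinds the name only; the closure still holds 1.0
alpha_var.set_value(0.5)  # mutates the object the closure refers to

print(loss_with_float(2.0, 4.0))  # 2.0
print(loss_with_var(2.0, 4.0))    # 3.0
```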
Thanks @akamojo. Yes, I see you used this kind of "loss function factory" structure. However, I am still struggling to see the alpha value updated after each epoch. I don't know if it's a version issue (Python 2.7, TF 1.15, Keras 2.2.4), but the alpha update callback doesn't seem to return anything or otherwise change anything. Sure the If I observe the value of EDIT: Never mind, I think I figured it out. You are supposed to set that value afterwards with something like |
Yes, exactly @tipani86! I guess, we did not paste this line of code, sorry about that. |
So I finally merged the pull request (thanks a lot for it), and also added the scheduler that @marcinkaczor suggested. Since the journal extension is now published and online, I'll keep updating this repo in the coming days; I've yet to add the new interesting stuff (3D computation of the boundary loss and distance maps). |
@HKervadec it's great to hear that! Looking forward to seeing updates of your repo. |
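How can I check whether alpha has changed after each epoch? I put alpha directly into metrics and got an error. I also don't quite understand why alpha is written as alpha = K.variable(1, dtype='float32') instead of alpha=1. What's the difference? |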
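from keras import backend as K
from keras.callbacks import Callback
import numpy as np

class AlphaScheduler(Callback):
    def __init__(self, epochs):
        super().__init__()
        self.epochs = epochs

    def on_epoch_end(self, epoch, logs=None):
        # Linear decay from 1 towards 0.01 over self.epochs epochs
        updated_alpha = np.clip(1 - 1.0 * (epoch + 1) / self.epochs, 0.01, 1)
        # alpha is the global K.variable that was passed to the loss factory
        K.set_value(alpha, updated_alpha)
        print("epoch: {}".format(epoch + 1))
        print("alpha: {}".format(K.get_value(alpha)))
For the actual model loss, you would use the result from the loss factory |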
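The two alpha schedules that appear in this thread behave very differently on short trainings. A small NumPy simulation, assuming alpha starts at 1 as above (`step_schedule` and `linear_schedule` are hypothetical names): the fixed step of 0.01 per epoch needs about 99 epochs to reach the 0.01 floor, while the epoch-based version always reaches it at the last epoch.

```python
import numpy as np


def step_schedule(num_epochs: int, step: float = 0.01) -> list:
    """update_alpha from earlier in the thread: subtract `step` per epoch, clip to [0.01, 1]."""
    alpha, values = 1.0, []
    for _ in range(num_epochs):
        alpha = float(np.clip(alpha - step, 0.01, 1))
        values.append(alpha)
    return values


def linear_schedule(num_epochs: int) -> list:
    """AlphaScheduler(epochs) above: alpha = clip(1 - (epoch + 1) / epochs, 0.01, 1)."""
    return [float(np.clip(1 - 1.0 * (e + 1) / num_epochs, 0.01, 1)) for e in range(num_epochs)]


short = 9
print(step_schedule(short)[-1])    # alpha after 9 epochs with the fixed step
print(linear_schedule(short)[-1])  # alpha after 9 epochs with the epoch-based schedule
```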
thx! |
Is the metric monitored for adjusting the learning rate still val_loss? |
Yeah, this may not work well for you. Because alpha is a moving ratio between two different losses, which may have very different values, the absolute combined loss value might actually go up between epochs. The way I approached this problem was to use an LR scheduling method that does not depend on how the loss develops, for example a continuously decreasing LR with some minimum value that it should not go below in the end. |
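I am not very familiar with the ReduceLROnPlateau callback. The continuous LR reduction can be written, for example, as follows:
def scheduler(epoch):
    # Linear decay from lr at epoch 0 down to lr / epochs at the last epoch
    new_lr = lr - epoch * lr / epochs
    return new_lr
where lr is the initial learning rate and epochs is the total number of epochs.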
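A sketch of the same linear decay with the minimum value mentioned in the previous comment, written as a small factory so lr and epochs are explicit parameters. The names and numbers are illustrative; a function of this shape (taking the epoch index, returning a learning rate) is what keras.callbacks.LearningRateScheduler expects.

```python
def make_linear_lr_schedule(initial_lr: float, epochs: int, min_lr: float):
    """Linear decay from initial_lr that never goes below min_lr (hypothetical helper)."""
    def scheduler(epoch: int) -> float:
        new_lr = initial_lr - epoch * initial_lr / epochs
        return max(new_lr, min_lr)
    return scheduler


schedule = make_linear_lr_schedule(initial_lr=1e-3, epochs=50, min_lr=1e-5)
print([schedule(e) for e in (0, 25, 49, 60)])
```

Because the schedule only depends on the epoch index, it is unaffected by the combined loss going up while alpha shifts weight between the two terms.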
Sorry, I have not worked with medical image segmentation before, so I am not familiar with your problem setting.
I'm afraid each case is different and I cannot comment on how your data/code performs. However, you should be aware that in its original form, this surface loss does not decrease towards 0, but towards some arbitrary negative number determined by the ground truth mask. I think this issue was raised elsewhere already. So it's difficult to say how your combined loss will behave when it mixes two individual losses with very different output ranges. What you can do is try to normalize the surface loss output so that it's always a positive number that decreases towards 0, for example. P.S. Could we use English for the benefit of everybody else participating in this discussion? |
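One possible way to get a surface loss that is non-negative and reaches 0 at its optimum (a sketch, not the repository's method): subtract the smallest value mean(probs * dist) can take when the probabilities sum to 1 over the class axis. Per pixel that minimum puts all mass on the class with the smallest distance, and ground-truth signed maps built with the formula above attain it at the ground truth. Channels-first arrays and the function names are assumptions.

```python
import numpy as np


def surface_loss_lower_bound(dist_maps: np.ndarray) -> float:
    """Smallest value of mean(probs * dist_maps) when probs sum to 1 over axis 0."""
    # Per pixel, the best prediction puts all its mass on the class with the smallest distance
    return float(dist_maps.min(axis=0).sum() / dist_maps.size)


def shifted_surface_loss(probs: np.ndarray, dist_maps: np.ndarray) -> float:
    """Surface loss offset so that it is >= 0 and equals 0 at its optimum."""
    raw = float((probs * dist_maps).mean())
    return raw - surface_loss_lower_bound(dist_maps)


rng = np.random.default_rng(0)
dist = rng.uniform(-3, 3, size=(3, 4, 4))       # arbitrary (C, H, W) distance maps
logits = rng.normal(size=(3, 4, 4))
probs = np.exp(logits) / np.exp(logits).sum(axis=0, keepdims=True)
best = (dist == dist.min(axis=0, keepdims=True)).astype(float)  # one-hot at the argmin class

print(shifted_surface_loss(probs, dist), shifted_surface_loss(best, dist))
```

The offset is a constant per ground truth, so it does not change the gradients; it only makes the logged value easier to compare with the regional loss.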
thx! |
Sorry to bother you again! |
I am sorry, I don't have any experience with U-Nets or PyTorch. That's why I only came to this specific issue, to check out the Keras+TensorFlow implementation. I used the Keras version of DeepLabV3+ for my problem... Edit: I don't know if it's a problem-specific thing, but training for 9 epochs in general seems a bit short. In our problem we generally trained for 50-100 epochs, sometimes even more. For the alpha transition between the surface loss and the initial "carrier loss", having more epochs will also lead to a smoother change in the gradient. |
thx! |
As I said earlier, I am not at all familiar with medical imaging. In my use case the object is usually relatively large. When the object becomes small, DeepLabV3+ also has some trouble producing good segmentation results. In your case, though, the model seems to be overfitting. Maybe your training set is too small; have you added some image augmentation to boost your training set size? |
#14 (comment) |
Oh, I just found the right issue for my problem. Thank you! |
Glad that you've found it. To make things explicit, the implementation of
Do not hesitate to create another issue if you feel your use case is a bit different, I'll be happy to help : ) |
Hello, how should I understand this? |
Does the FAQ answer your question?
|
Could you provide full training scripts for Keras? |
My results show that at the beginning the boundary loss is positive (around 0.3), and after training for a long time it becomes 0.0000 or -0.0000. I wonder whether this is right. Should it be negative? Thanks! (I use it in a 3D U-Net for medical segmentation of very small objects.) @HKervadec |
Hi @HKervadec |
Unfortunately no, as I do not have the time for that, and do not know Keras either. Have you tried contacting the GDL authors? Also, bear in mind that you can combine the boundary loss with any other loss, like cross-entropy, focal loss, or anything really.
It might help, or it might not. The only way to find out is to try. |
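It sounds good, though you might want to display more decimal places when monitoring the results. Things that you can modify and try:
I guess the normalization of the distance map would go like this:
import torch
from torch import Tensor

def norm_distmap(distmap: Tensor) -> Tensor:
    # Divide by the largest absolute distance so that the map lies in [-1, 1]
    _m = torch.abs(distmap).max()  # 0-d tensor; an all-zero map would give 0 / 0 = NaN
    return distmap / _m |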
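A NumPy counterpart of norm_distmap with a guard added, since dividing an all-zero map (for instance an image with no object at all) by its maximum gives 0 / 0 = NaN. This is a sketch; the function name and the eps threshold are assumptions.

```python
import numpy as np


def norm_distmap_np(distmap: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale a signed distance map into [-1, 1]; an all-zero map is returned unchanged."""
    m = np.abs(distmap).max()
    if m < eps:  # nothing to normalize; avoids 0 / 0 = NaN
        return distmap.astype(np.float64)
    return distmap / m


print(norm_distmap_np(np.array([[2.0, 1.0, 0.0, -1.0]])))
print(norm_distmap_np(np.zeros((2, 2))))
```

With the map bounded in [-1, 1], the surface loss stays in the same range, regardless of image size.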
I'm using DeepLab with the Keras implementation here of boundary loss. The loss steadily decreases, but when I stop and have the model predict, the predictions are all NaN. I was wondering if anyone else had this problem or things I might be doing wrong. |
Hi, thank you for this. Do I have to compute the dist_map of my predictions before passing them to the surface_loss function? |
As the boundary loss is merely an element-wise multiplication between two tensors, it cannot produce NaN unless its inputs contain NaN themselves (so either in the softmax or in the distance map). My guess would be that some other part of your code is unsafe (in the sense that it can produce NaN, with |
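As an illustration of the kind of "unsafe" code meant here (not a diagnosis of this particular bug), a naive softmax overflows to NaN on large logits, while subtracting the per-pixel maximum first keeps it finite. A NumPy sketch with a small finiteness check to locate NaN before it reaches the loss; both function names are hypothetical.

```python
import numpy as np


def stable_softmax(logits: np.ndarray, axis: int = 0) -> np.ndarray:
    """Softmax that subtracts the max along `axis` first, so exp() cannot overflow to inf."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def check_finite(name: str, x: np.ndarray) -> None:
    """Fail early, naming the array that contains NaN or inf."""
    if not np.all(np.isfinite(x)):
        raise ValueError(f"{name} contains NaN or inf")


logits = np.array([[1000.0, -1000.0], [999.0, 0.0]])  # (C=2, N=2)
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(logits) / np.exp(logits).sum(axis=0, keepdims=True)  # inf / inf -> nan
probs = stable_softmax(logits)
check_finite("probs", probs)
```

Calling such a check on both the probabilities and the distance maps narrows down which input brings the NaN into the product.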
No, I think in this version the distance map is computed on the fly, so you need to send only the predicted probabilities to the loss. |
Thanks for the response! I later figured out it was something wrong with the model architecture. |
Hi, can you share the Keras implementation of the generalized Dice loss? |
Hi!
Could you help me implement the surface loss function in Keras and TensorFlow?
Thank you very much in advance!
Best wishes.