The mask related code in the Dice loss function is wrong #8

Open
nikolakopoulos opened this issue Jun 21, 2021 · 0 comments

nikolakopoulos commented Jun 21, 2021

Hello,

First of all, cool work! :)

Now let me get to the point:

I found the following bug in your code:

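    # Here is the problem! flat_input and flat_target are already one-hot,
    # i.e. shape [N, C], while mask has shape [N]. Broadcasting aligns the
    # last dimension, so mask is matched against C instead of N: this raises
    # an error when N != C and silently masks classes when N == C.
    if mask is not None:
        mask = mask.float()
        flat_input = flat_input * mask
        flat_target = flat_target * mask
    else:
        mask = torch.ones_like(target)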
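
To make the shape problem concrete, here is a minimal standalone sketch with toy tensors. It assumes, as the comment above implies, that `flat_target` is a one-hot matrix of shape `[N, C]` and that `mask` is a per-token vector of shape `[N]`; the sizes and values are made up for illustration. Because PyTorch broadcasting aligns trailing dimensions, the mask gets matched against the class axis rather than the token axis.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy batch: N = 4 tokens, C = 3 classes.
N, C = 4, 3
target = torch.tensor([0, 2, 1, 2])        # class index per token, shape [N]
mask = torch.tensor([1.0, 1.0, 0.0, 1.0])  # 1 = keep token, 0 = ignore, shape [N]

# After one-hot encoding, the target is a matrix of shape [N, C].
flat_target = F.one_hot(target, num_classes=C).float()

# Broadcasting aligns TRAILING dimensions, so the [N] mask is compared with
# the class axis C, not the token axis N. With N != C this raises an error.
try:
    _ = flat_target * mask
    broadcast_failed = False
except RuntimeError:
    broadcast_failed = True
print("broadcast failed:", broadcast_failed)

# With N == C there is no error, but the mask zeroes out CLASSES (columns)
# instead of TOKENS (rows), which silently corrupts the loss.
sq_target = F.one_hot(torch.tensor([0, 0, 2]), num_classes=3).float()  # [3, 3]
sq_mask = torch.tensor([0.0, 1.0, 1.0])
wrong = sq_target * sq_mask              # masks columns
intended = sq_target * sq_mask[:, None]  # masks rows, as intended
print(wrong)
print(intended)
```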

An easy fix is to transpose before multiplying, so that the mask is applied along the token dimension:

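    if mask is not None:
        mask = mask.float()
        # Transpose to [C, N] so the [N] mask lines up with the last
        # dimension, then transpose back to [N, C].
        flat_input = (flat_input.t() * mask).t()
        flat_target = (flat_target.t() * mask).t()
    else:
        mask = torch.ones_like(target)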
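
The transpose works because `flat_input.t()` has shape `[C, N]`, whose last dimension now matches the `[N]` mask; the second `.t()` restores `[N, C]`. The sketch below (toy tensors, hypothetical sizes, softmax probabilities standing in for `flat_input`) checks that this agrees with the more explicit `mask.unsqueeze(-1)`, which reshapes the mask to `[N, 1]` so it broadcasts across classes, and that rows of ignored tokens become all zeros.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy batch: N = 5 tokens, C = 3 classes.
N, C = 5, 3
logits = torch.randn(N, C)
target = torch.tensor([0, 2, 1, 1, 0])
mask = torch.tensor([1, 1, 0, 1, 0]).float()  # 0 marks ignored tokens

flat_input = torch.softmax(logits, dim=1)               # [N, C]
flat_target = F.one_hot(target, num_classes=C).float()  # [N, C]

# Fix proposed in this issue: transpose so the [N] mask broadcasts over tokens.
masked_input_t = (flat_input.t() * mask).t()
masked_target_t = (flat_target.t() * mask).t()

# Equivalent, more explicit form: reshape the mask to [N, 1].
masked_input_u = flat_input * mask.unsqueeze(-1)
masked_target_u = flat_target * mask.unsqueeze(-1)

print(torch.allclose(masked_input_t, masked_input_u))  # True
print(torch.equal(masked_target_t, masked_target_u))   # True

# Rows belonging to ignored tokens are now entirely zero.
print(masked_input_t[mask == 0].abs().sum().item())    # 0.0
```

One design note: `Tensor.t()` only accepts tensors with at most two dimensions, so if the inputs were ever left unflattened (e.g. `[B, L, C]` with a `[B, L]` mask), the `unsqueeze(-1)` form still works where the transpose does not.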