
dice loss #396

Merged: 4 commits merged into master on Mar 11, 2021
Conversation

@xiexinch (Collaborator) commented Mar 2, 2021

No description provided.

@CLAassistant commented Mar 2, 2021

CLA assistant check
All committers have signed the CLA.

@codecov bot commented Mar 2, 2021

Codecov Report

Merging #396 (28ae84f) into master (d0a71c1) will increase coverage by 0.12%.
The diff coverage is 86.56%.


@@            Coverage Diff             @@
##           master     #396      +/-   ##
==========================================
+ Coverage   86.20%   86.32%   +0.12%     
==========================================
  Files          96       98       +2     
  Lines        4906     4973      +67     
  Branches      799      808       +9     
==========================================
+ Hits         4229     4293      +64     
- Misses        523      525       +2     
- Partials      154      155       +1     
Flag       Coverage Δ
unittests  86.32% <86.56%> (+0.12%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files                                  Coverage Δ
mmseg/models/decode_heads/fcn_dilate_head.py    42.85% <42.85%> (ø)
mmseg/models/losses/dice_loss.py                98.03% <98.03%> (ø)
mmseg/models/decode_heads/__init__.py           100.00% <100.00%> (ø)
mmseg/models/losses/__init__.py                 100.00% <100.00%> (ø)
mmseg/models/losses/utils.py                    76.66% <0.00%> (+19.99%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Last update d0a71c1...20c9f62.

@xvjiarui (Collaborator) commented Mar 2, 2021

Comment on lines 45 to 48:

class DiceLoss(nn.Module):
    """DiceLoss.

    """

We may add some docstring here.

valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1)

num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
den = torch.sum((pred.pow(exponent) + target.pow(exponent)) * valid_mask, dim=1) + smooth

We may directly use denominator (and numerator) as the variable names.
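For reference, the snippet above is the core of the per-sample Dice computation. A minimal standalone sketch (a hypothetical helper, not the merged mmseg code) using spelled-out names:

```python
import torch

def binary_dice(pred, target, valid_mask, smooth=1, exponent=2):
    # Flatten each sample to a vector before reducing over dim=1.
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1)
    valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

    numerator = torch.sum(pred * target * valid_mask, dim=1) * 2 + smooth
    denominator = torch.sum(
        (pred.pow(exponent) + target.pow(exponent)) * valid_mask, dim=1) + smooth
    return 1 - numerator / denominator  # per-sample Dice loss

# A perfect prediction drives the loss to 0.
pred = torch.tensor([[1.0, 0.0, 1.0]])
target = torch.tensor([[1.0, 0.0, 1.0]])
loss = binary_dice(pred, target, torch.ones_like(pred))
```

The smooth term keeps the ratio well-defined when both prediction and target are empty.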

class_weight (list[float], optional): The weight for each class.
    Default: None.
loss_weight (float, optional): Weight of the loss. Default to 1.0.
ignore_index (int | None): The label index to be ignored. Default: -1.

Should be 255.
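The suggested default of 255 matches the convention in semantic segmentation datasets, where label 255 marks pixels to exclude from the loss. A small sketch of how such an index would feed into the valid mask:

```python
import torch

# mmseg conventionally reserves label 255 for "ignore" regions,
# hence the reviewer's suggested default.
ignore_index = 255
target = torch.tensor([[0, 1, 255],
                       [2, 255, 1]])
valid_mask = (target != ignore_index).long()
# Pixels labeled 255 get mask 0 and drop out of the Dice sums.
```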

@xvjiarui (Collaborator) commented:

FCN dilate model shouldn't be in this PR.

@xvjiarui xvjiarui merged commit 7e1b24d into open-mmlab:master Mar 11, 2021
else:
    class_weight = None

pred = F.softmax(pred, dim=1)

In instance segmentation, dice loss may use sigmoid for activation. Suggest supporting both cases.
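One way the suggestion could look, sketched as a hypothetical `activate` helper with a `use_sigmoid` flag (names are illustrative, not the merged API):

```python
import torch
import torch.nn.functional as F

def activate(pred, use_sigmoid=False):
    # Hypothetical switch: softmax across classes for semantic
    # segmentation, per-channel sigmoid for binary/instance settings.
    if use_sigmoid:
        return pred.sigmoid()
    return F.softmax(pred, dim=1)

logits = torch.zeros(2, 3, 4, 4)           # (N, C, H, W)
soft = activate(logits)                    # class probs sum to 1 per pixel
sig = activate(logits, use_sigmoid=True)   # each channel independently 0.5
```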

assert pred.shape[0] == target.shape[0]
total_loss = 0
num_classes = pred.shape[1]
for i in range(num_classes):

Using a for loop might be inefficient; some implementations process all classes in a batched manner.
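A batched alternative to the per-class loop, sketched with a hypothetical `dice_loss_batched` that one-hot encodes the target and reduces over all classes at once (illustrative only, not the merged implementation):

```python
import torch
import torch.nn.functional as F

def dice_loss_batched(pred, target, smooth=1, exponent=2):
    # pred: (N, C, H, W) probabilities; target: (N, H, W) class indices.
    # One-hot the target and reduce over spatial positions for all
    # classes at once, replacing the Python-level loop over num_classes.
    num_classes = pred.shape[1]
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    pred = pred.flatten(2)        # (N, C, H*W)
    one_hot = one_hot.flatten(2)
    numerator = 2 * (pred * one_hot).sum(dim=2) + smooth
    denominator = (pred.pow(exponent) + one_hot.pow(exponent)).sum(dim=2) + smooth
    return (1 - numerator / denominator).mean()

target = torch.tensor([[[0, 1], [1, 0]]])                 # (1, 2, 2)
pred = F.one_hot(target, 2).permute(0, 3, 1, 2).float()   # perfect prediction
loss = dice_loss_batched(pred, target)
```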

@weighted_loss
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwargs):
    assert pred.shape[0] == target.shape[0]
    pred = pred.contiguous().view(pred.shape[0], -1)

Could .contiguous().view() be replaced by reshape?
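The reviewer's point, in a small example: `view` requires a contiguous tensor, while `reshape` accepts either layout and copies only when it has to.

```python
import torch

x = torch.arange(6).reshape(2, 3)
y = x.t()                 # transposing yields a non-contiguous view
# y.view(-1) would raise a RuntimeError here; reshape handles both
# layouts, copying data only when the tensor is non-contiguous.
flat = y.reshape(-1)
```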

bowenroom pushed a commit to bowenroom/mmsegmentation that referenced this pull request Feb 25, 2022
* dice loss

* format code, add docstring and calculate denominator without valid_mask

* minor change

* restore
aravind-h-v pushed a commit to aravind-h-v/mmsegmentation that referenced this pull request Mar 27, 2023
* ddim docs

for issue open-mmlab#293

* space
wjkim81 pushed a commit to wjkim81/mmsegmentation that referenced this pull request Dec 3, 2023