Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix binary segmentation when num_classes==1 #2016

Merged
merged 11 commits into from
Sep 8, 2022

Conversation

xiexinch
Copy link
Collaborator

@xiexinch xiexinch commented Sep 2, 2022

Motivation

Fix issue #1774

Modification

  1. Add threshold and out_channels parameters to decode_head
  2. Fix inference logits while num_classes==1

BC-breaking (Optional)

None

@codecov
Copy link

codecov bot commented Sep 2, 2022

Codecov Report

Base: 89.08% // Head: 89.10% // Increases project coverage by +0.02% 🎉

Coverage data is based on head (e18fe43) compared to base (74e13cf).
Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2016      +/-   ##
==========================================
+ Coverage   89.08%   89.10%   +0.02%     
==========================================
  Files         145      145              
  Lines        8691     8711      +20     
  Branches     1463     1470       +7     
==========================================
+ Hits         7742     7762      +20     
  Misses        707      707              
  Partials      242      242              
Flag Coverage Δ
unittests 89.10% <100.00%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
mmseg/models/decode_heads/decode_head.py 93.45% <100.00%> (+0.82%) ⬆️
mmseg/models/segmentors/cascade_encoder_decoder.py 94.73% <100.00%> (+0.14%) ⬆️
mmseg/models/segmentors/encoder_decoder.py 89.17% <100.00%> (+0.50%) ⬆️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

@@ -245,7 +245,10 @@ def inference(self, img, img_meta, rescale):
seg_logit = self.slide_inference(img, img_meta, rescale)
else:
seg_logit = self.whole_inference(img, img_meta, rescale)
output = F.softmax(seg_logit, dim=1)
if self.num_classes == 1:
Copy link
Collaborator

@MeowZheng MeowZheng Sep 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.num_classes == 1:
if self.out_channels == 1:

It would be better to use out_channels to define the output channels or post-processes rather than num classes

@@ -260,7 +263,11 @@ def inference(self, img, img_meta, rescale):
def simple_test(self, img, img_meta, rescale=True):
"""Simple test with single image."""
seg_logit = self.inference(img, img_meta, rescale)
seg_pred = seg_logit.argmax(dim=1)
if self.num_classes == 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same problems as above

@@ -283,7 +290,11 @@ def aug_test(self, imgs, img_metas, rescale=True):
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
seg_logit += cur_seg_logit
seg_logit /= len(imgs)
seg_pred = seg_logit.argmax(dim=1)
if self.num_classes == 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same problems as above

Comment on lines 94 to 98
if self.num_classes == 2 and self.out_channels is None:
warnings.warn('set out_channels to 1 to reduce the number of model\
parameters')
elif self.num_classes == 2 and self.out_channels == 1:
self.num_classes = 1
Copy link
Collaborator

@MeowZheng MeowZheng Sep 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if self. out_channels is None:
   out_channels = num_classes
    if num_classes == 2:
        warning.warn('For binary segmentation, we suggest using `out_channels = 1` to define the output channels of segmentor', and use `threshold` to  convert seg_logist into a prediction applying a threshold
if out_channels == 1 and threshold is None:
    threshold = 0.3
    warning.warn('threshold is not defined for binary, and defaults to 0.3)
self.num_classes = num_classes
self.out_channels = out_channels
self.threshold=threshold

I suggest don't change num_classes value and tell user the default threshold is 0.3

Comment on lines 22 to 23
threshold (float): Threshold for binary segmentation in the case of
`num_classes==1`. Default: 0.3.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

threshold default to None.

Comment on lines 83 to 84
self.out_channels = out_channels
self.threshold = threshold
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.out_channels = out_channels
self.threshold = threshold

@@ -56,6 +59,8 @@ def __init__(self,
channels,
*,
num_classes,
out_channels=None,
threshold=0.3,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

threshold = None.

@MeowZheng MeowZheng mentioned this pull request Sep 8, 2022
if out_channels != num_classes and out_channels != 1:
raise ValueError(
'out_channels should be equal to num_classes,'
'except out_channels == 1 and num_classes == 2, but got'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
'except out_channels == 1 and num_classes == 2, but got'
'except binary segmentation set out_channels == 1 and num_classes == 2, but got'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants