-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fix] Fix binary segmentation when num_classes==1
#2016
Conversation
Codecov ReportBase: 89.08% // Head: 89.10% // Increases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## master #2016 +/- ##
==========================================
+ Coverage 89.08% 89.10% +0.02%
==========================================
Files 145 145
Lines 8691 8711 +20
Branches 1463 1470 +7
==========================================
+ Hits 7742 7762 +20
Misses 707 707
Partials 242 242
Flags with carried forward coverage won't be shown. Click here to find out more.
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
@@ -245,7 +245,10 @@ def inference(self, img, img_meta, rescale): | |||
seg_logit = self.slide_inference(img, img_meta, rescale) | |||
else: | |||
seg_logit = self.whole_inference(img, img_meta, rescale) | |||
output = F.softmax(seg_logit, dim=1) | |||
if self.num_classes == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if self.num_classes == 1: | |
if self.out_channels == 1: |
It would be better to use out_channels to define the output channels or post-processes rather than num classes
@@ -260,7 +263,11 @@ def inference(self, img, img_meta, rescale): | |||
def simple_test(self, img, img_meta, rescale=True): | |||
"""Simple test with single image.""" | |||
seg_logit = self.inference(img, img_meta, rescale) | |||
seg_pred = seg_logit.argmax(dim=1) | |||
if self.num_classes == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same problems as above
@@ -283,7 +290,11 @@ def aug_test(self, imgs, img_metas, rescale=True): | |||
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) | |||
seg_logit += cur_seg_logit | |||
seg_logit /= len(imgs) | |||
seg_pred = seg_logit.argmax(dim=1) | |||
if self.num_classes == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same problems as above
if self.num_classes == 2 and self.out_channels is None: | ||
warnings.warn('set out_channels to 1 to reduce the number of model\ | ||
parameters') | ||
elif self.num_classes == 2 and self.out_channels == 1: | ||
self.num_classes = 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if self. out_channels is None:
out_channels = num_classes
if num_classes == 2:
warning.warn('For binary segmentation, we suggest using `out_channels = 1` to define the output channels of segmentor', and use `threshold` to convert seg_logist into a prediction applying a threshold
if out_channels == 1 and threshold is None:
threshold = 0.3
warning.warn('threshold is not defined for binary, and defaults to 0.3)
self.num_classes = num_classes
self.out_channels = out_channels
self.threshold=threshold
I suggest don't change num_classes
value and tell user the default threshold is 0.3
threshold (float): Threshold for binary segmentation in the case of | ||
`num_classes==1`. Default: 0.3. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
threshold default to None.
self.out_channels = out_channels | ||
self.threshold = threshold |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.out_channels = out_channels | |
self.threshold = threshold |
@@ -56,6 +59,8 @@ def __init__(self, | |||
channels, | |||
*, | |||
num_classes, | |||
out_channels=None, | |||
threshold=0.3, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
threshold = None.
if out_channels != num_classes and out_channels != 1: | ||
raise ValueError( | ||
'out_channels should be equal to num_classes,' | ||
'except out_channels == 1 and num_classes == 2, but got' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
'except out_channels == 1 and num_classes == 2, but got' | |
'except binary segmentation set out_channels == 1 and num_classes == 2, but got' |
Motivation
Fix issue #1774
Modification
threshold
andout_channels
parameters todecode_head
num_classes==1
BC-breaking (Optional)
None