[Feature] Add support for the focal Tversky loss #2791
Conversation
Codecov Report
Patch coverage: additional details and impacted files below.
@@ Coverage Diff @@
## master #2791 +/- ##
==========================================
- Coverage 88.13% 88.11% -0.02%
==========================================
Files 149 149
Lines 9183 9187 +4
Branches 1539 1540 +1
==========================================
+ Hits 8093 8095 +2
- Misses 835 836 +1
- Partials 255 256 +1
Hi @zifuwanggg,
Thanks for your contribution! We really appreciate it. Here are some comments that should be resolved; please take a look.
gamma (float, in [1, inf]): The focal term. When `gamma` > 1,
    the loss focuses more on less accurate predictions that
    have been misclassified. Default: 1.0.
Might add the paper link https://arxiv.org/abs/1810.07842 to the docstring above.
I've added the reference in the docstring.
             loss_name='loss_tversky'):
    super(TverskyLoss, self).__init__()
    self.smooth = smooth
    self.class_weight = get_class_weight(class_weight)
    self.loss_weight = loss_weight
    self.ignore_index = ignore_index
    assert (alpha + beta == 1.0), 'Sum of alpha and beta must be 1.0!'
    assert gamma >= 1.0, 'gamma should be at least 1.0!'
Since there is an assertion statement, we might add a unit test for it.
I've added an assertion test.
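For readers following along, a minimal assertion test of this kind might look like the sketch below. The `TverskyLoss` stub here is a hypothetical, self-contained re-implementation of just the constructor checks quoted above; the real class lives in mmseg/models/losses/tversky_loss.py, and the actual test in the PR (which typically uses `pytest.raises`) may differ.

```python
# Hypothetical stand-in for the constructor checks shown in the snippet above.
class TverskyLoss:
    def __init__(self, alpha=0.3, beta=0.7, gamma=1.0):
        assert (alpha + beta == 1.0), 'Sum of alpha and beta must be 1.0!'
        assert gamma >= 1.0, 'gamma should be at least 1.0!'
        self.alpha, self.beta, self.gamma = alpha, beta, gamma


def raises_assertion(fn):
    """Return True if calling fn raises AssertionError."""
    try:
        fn()
    except AssertionError:
        return True
    return False


# Invalid alpha/beta sum and invalid gamma should both be rejected.
assert raises_assertion(lambda: TverskyLoss(alpha=0.5, beta=0.7))
assert raises_assertion(lambda: TverskyLoss(gamma=0.5))
# Valid arguments construct without error.
TverskyLoss(alpha=0.3, beta=0.7, gamma=2.0)
```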
mmseg/models/losses/tversky_loss.py
if gamma > 1.0:
    tversky_loss **= gamma
Should it be 1 / gamma?
In the official implementation (https://github.com/nabsabraham/focal-tversky-unet/blob/master/losses.py#L67), they calculate the FTL with gamma = 1/(4/3) = 0.75.
Yes indeed. I've modified it to 1 / gamma.
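Concretely, the corrected code raises the Tversky loss to the power 1 / gamma, so that gamma = 4/3 in this interface reproduces the exponent 0.75 used in the official implementation. A minimal sketch, with a plain float standing in for the per-class loss tensor:

```python
# Sketch of the corrected focal term. Assumption: tversky_loss is the
# per-class Tversky loss (1 - Tversky index), stubbed here as a float.
def focal_tversky(tversky_loss: float, gamma: float) -> float:
    if gamma > 1.0:
        # Exponent is 1 / gamma, matching the official implementation's 0.75
        # when gamma = 4/3.
        tversky_loss = tversky_loss ** (1.0 / gamma)
    return tversky_loss


# gamma = 4/3 reproduces the official exponent of 0.75:
assert abs(focal_tversky(0.5, 4 / 3) - 0.5 ** 0.75) < 1e-12
```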
Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help it get feedback more easily. If you do not understand some items, don't worry: just make the pull request and seek help from maintainers.
Motivation
The focal Tversky loss was proposed in https://arxiv.org/abs/1810.07842. It has nearly 600 citations and has been shown to be extremely useful for highly imbalanced (medical) datasets. To add support for the focal Tversky loss, only a few lines of changes to the existing Tversky loss are needed.
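As background for the changes discussed below, the loss from the paper can be sketched as follows. This is a hypothetical, self-contained illustration with plain floats standing in for the per-class tensors mmseg actually uses; the `smooth` term and default alpha/beta values are taken from the constructor snippet in this thread.

```python
# Tversky index: TI = TP / (TP + alpha * FN + beta * FP), smoothed for
# numerical stability.
def tversky_index(tp, fn, fp, alpha=0.3, beta=0.7, smooth=1.0):
    return (tp + smooth) / (tp + alpha * fn + beta * fp + smooth)


# Focal Tversky loss: FTL = (1 - TI) ** (1 / gamma). With gamma > 1 the
# exponent is < 1, which amplifies the loss for hard, poorly-segmented
# classes (where 1 - TI is small).
def focal_tversky_loss(tp, fn, fp, alpha=0.3, beta=0.7, gamma=4 / 3):
    ti = tversky_index(tp, fn, fp, alpha, beta)
    return (1.0 - ti) ** (1.0 / gamma)


# With gamma = 1 the focal Tversky loss reduces to the plain Tversky loss.
assert abs(focal_tversky_loss(8, 1, 1, gamma=1.0)
           - (1.0 - tversky_index(8, 1, 1))) < 1e-12
```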
Modification
Add `gamma` as an (optional) argument in the constructor of `TverskyLoss`. This parameter is then passed to `tversky_loss` to compute the focal Tversky loss.
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
Checklist
Reopening of previous PR.