[Feature] Add support for the focal Tversky loss #2783
Conversation
Zifu Wang seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account. Have you signed the CLA already but the status is still pending? Let us recheck it.
Codecov Report
Patch coverage:

Additional details and impacted files

@@            Coverage Diff             @@
##           master    #2783      +/-   ##
==========================================
- Coverage   88.36%   88.34%   -0.02%
==========================================
  Files         149      149
  Lines        9109     9112       +3
  Branches     1523     1524       +1
==========================================
+ Hits         8049     8050       +1
- Misses        810      811       +1
- Partials     250      251        +1
Hi @zifuwanggg,
mmseg/models/losses/tversky_loss.py (outdated)

@@ -75,6 +78,9 @@ class TverskyLoss(nn.Module):
        beta (float, in [0, 1]):
            The coefficient of false negatives. Default: 0.7.
            Note: alpha + beta = 1.
        gamma (float, in [1, 3]): The focal term. When `gamma` > 1, the loss
We may add an assertion below to ensure 1.0 < gamma < 3.0, and add a unit test as well.
The [1, 3] range was mentioned in the paper, although the authors do not explain why `gamma` should be in this range. Theoretically, `gamma` can be any real number. Practically, a small `gamma` (e.g. < 2) works better. `mmseg.models.losses.focal_loss` does not specify the range of `gamma`, so perhaps we should remove the requirement that `gamma` be in [1, 3] and let the user decide?
Besides, I also think `alpha` and `beta` in the original `TverskyLoss` class do not need to satisfy `assert alpha + beta == 1.0`. Floating-point equality is dangerous, and according to the definition of the Tversky index, the only requirement for `alpha` and `beta` is that they are non-negative numbers.
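To illustrate the point, here is a hypothetical plain-Python sketch of the Tversky index (function name and signature are assumptions, not mmseg's actual implementation) that only requires `alpha` and `beta` to be non-negative:

```python
def tversky_index(pred, target, alpha=0.3, beta=0.7, eps=1e-6):
    """Hypothetical sketch of the Tversky index for flat binary masks.

    Per the definition, alpha and beta only need to be non-negative;
    alpha + beta == 1.0 is not required (and alpha = beta = 0.5 recovers
    the Dice coefficient).
    """
    assert alpha >= 0 and beta >= 0  # the only constraint from the definition
    tp = sum(p * t for p, t in zip(pred, target))        # true positives
    fp = sum(p * (1 - t) for p, t in zip(pred, target))  # false positives
    fn = sum((1 - p) * t for p, t in zip(pred, target))  # false negatives
    return (tp + eps) / (tp + alpha * fp + beta * fn + eps)
```

With `alpha + beta != 1` (say `alpha=0.5, beta=2.0`) the index is still well defined, which supports dropping the equality assert.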
mmseg/models/losses/tversky_loss.py (outdated)

        if gamma > 1:
            tversky_loss **= (1 / gamma)
Might add a unit test.
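A sketch of what such a unit test could check (the helper mirrors the two-line patch above; names are assumptions, not the actual test suite):

```python
def focal_tversky(tversky_index_value, gamma=1.0):
    # Hypothetical helper mirroring the patch:
    # loss = (1 - TI) ** (1 / gamma) when gamma > 1, else plain 1 - TI.
    loss = 1.0 - tversky_index_value
    if gamma > 1:
        loss **= (1 / gamma)
    return loss

def test_focal_tversky_gamma():
    # gamma == 1 reduces to the plain Tversky loss
    assert focal_tversky(0.75, gamma=1.0) == 0.25
    # gamma > 1 raises (1 - TI) to 1/gamma, which inflates losses below 1
    assert abs(focal_tversky(0.75, gamma=2.0) - 0.25 ** 0.5) < 1e-9
    assert focal_tversky(0.75, gamma=2.0) > focal_tversky(0.75, gamma=1.0)
```

The last assertion pins down the intended behavior: for `gamma > 1` the focal term makes hard examples (small Tversky index gaps) contribute relatively more to the loss.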
Hi @xiexinch, thanks for the comment. The email I used for the very first commits was indeed by default
Hi @zifuwanggg,
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation
The focal Tversky loss was proposed in https://arxiv.org/abs/1810.07842. It has nearly 600 citations and has been shown to be extremely useful for highly imbalanced (medical) datasets. To add support for the focal Tversky loss, only a few lines of changes are needed for the Tversky loss.

Modification
Add `gamma` as an (optional) argument in the constructor of `TverskyLoss`. This parameter is then passed to `tversky_loss` to compute the focal Tversky loss.

BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repos? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist
1. Pre-commit or other linting tools are used to fix the potential lint issues.
2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D.
4. The documentation has been modified accordingly, like docstring or example tutorials.

Reopening of previous [PR](#2783).
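A hypothetical usage sketch of the new argument in an mmseg-style `loss_decode` config dict (the default values for `alpha`, `beta`, and `loss_weight` shown here are assumptions; only `gamma` is the addition described in this PR):

```python
# Hypothetical decode-head loss config enabling the focal variant:
# gamma == 1 keeps the plain Tversky loss; gamma > 1 turns on the
# focal term (1 - TI) ** (1 / gamma) from the patch.
loss_decode = dict(
    type='TverskyLoss',
    alpha=0.3,         # weight of false positives
    beta=0.7,          # weight of false negatives
    gamma=2.0,         # focal term added by this PR
    loss_weight=1.0)
```

Since `gamma` is optional, existing configs without it should keep their old behavior, which is why the PR is not expected to be BC-breaking.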