Add weight in DiceLoss
#7098
Signed-off-by: KumoLiu <yunl@nvidia.com>
Thanks, it looks good to me. cc @myron, feel free to create a further request if you have any concerns.
/build
Thanks for adding. But why is "self.class_weight" a property and registered as a buffer? We recompute it every time on the fly from self.weight. I think "class_weight" can be just an intermediate variable, to avoid confusion. Thanks.
Hi @myron, I think the main reason for registering the class_weight as a buffer is that "class_weight" doesn't require training optimization.
@KumoLiu, "self.class_weight" does not require optimization; it is also fully defined by "self.weight" and is recomputed every time in your implementation. As far as I can tell, there is no reason to define "self.class_weight" as a class property of DiceLoss() at all. We never update that property directly. Unless there is a reason, it should not be a property of DiceLoss. Please remove it as a class property; you can simply have an intermediate local variable called "class_weight" during loss calculation. And in general, if it is a simple local variable, do not declare new class properties; it will be very confusing for users. CC @wyli @Nic-Ma
Thanks @myron, this is to be consistent with the pytorch weighted loss interface, with the benefit of saving the weights as part of the training stats: https://github.com/pytorch/pytorch/blob/798efab53274ff44d0b5bbd2de59299b529e757c/torch/nn/modules/loss.py#L28 If there's significant computational overhead, we can look into refactoring.
There is no compute overhead, but it looks like a bad coding practice to me. We have self.weight and self.class_weight now, which both represent the same thing, and furthermore self.class_weight is recomputed every time from self.weight in the forward() pass.
Currently, it looks confusing, and a user may attempt to change the dice_loss.class_weight=.. property directly, which will not accomplish anything. We should try to simplify our code, unless I'm missing the benefit of this extra property. Thanks.
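To make the concern above concrete, here is a minimal pure-Python sketch of the pattern being discussed. It is not the code from this PR: `ToyWeightedLoss` and its `forward()` signature are hypothetical stand-ins, and plain lists are used instead of tensors or buffers. It only illustrates that an attribute rebuilt from `weight` on every call silently ignores direct assignment.

```python
from typing import List, Optional, Sequence


class ToyWeightedLoss:
    """Hypothetical stand-in for the pattern under discussion (not MONAI's DiceLoss)."""

    def __init__(self, weight: Optional[Sequence[float]] = None) -> None:
        self.weight = weight  # the user-facing setting
        self.class_weight: List[float] = [1.0]  # stored attribute, rebuilt in forward()

    def forward(self, per_class_loss: Sequence[float]) -> float:
        num_classes = len(per_class_loss)
        # class_weight is recomputed from self.weight on every call,
        # so whatever was assigned to it beforehand is discarded.
        if self.weight is None:
            self.class_weight = [1.0] * num_classes
        else:
            self.class_weight = [float(w) for w in self.weight]
        if len(self.class_weight) != num_classes:
            raise ValueError("weight must have one entry per class")
        weighted = [w * v for w, v in zip(self.class_weight, per_class_loss)]
        return sum(weighted) / num_classes


loss = ToyWeightedLoss(weight=[1.0, 2.0])
before = loss.forward([0.5, 0.25])

loss.class_weight = [10.0, 10.0]  # direct assignment: overwritten on the next call
after = loss.forward([0.5, 0.25])

loss.weight = [2.0, 2.0]  # changing the real setting does take effect
changed = loss.forward([0.5, 0.25])
```

In this sketch `before` and `after` are identical, while `changed` differs. Whether the stored attribute is still worth keeping (for example, so that it is saved with the module state) is the trade-off being debated in this thread.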
perhaps the two variables (also, this discussion has more details about the usage: https://discuss.pytorch.org/t/what-is-the-difference-between-register-buffer-and-register-parameter-of-nn-module/32723)
Fixes #7065.
Description
Add weight in DiceLoss.
Types of changes
./runtests.sh -f -u --net --coverage
./runtests.sh --quick --unittests --disttests
make html command in the docs/ folder.