Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add weight in DiceLoss #7098

Merged
merged 16 commits into from
Oct 8, 2023
Merged

Add weight in DiceLoss #7098

merged 16 commits into from
Oct 8, 2023

Conversation

KumoLiu
Copy link
Contributor

@KumoLiu KumoLiu commented Oct 7, 2023

Fixes #7065.

Description

  • standardize the naming to be simply "weight".
  • add this "weight" parameter to DiceLoss.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

KumoLiu and others added 7 commits October 7, 2023 16:58
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
@KumoLiu KumoLiu marked this pull request as ready for review October 8, 2023 01:54
@KumoLiu KumoLiu requested review from wyli, yiheng-wang-nv, myron and Nic-Ma and removed request for yiheng-wang-nv October 8, 2023 01:57
monai/losses/dice.py Outdated Show resolved Hide resolved
KumoLiu and others added 7 commits October 8, 2023 15:24
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
Copy link
Contributor

@wyli wyli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, it looks good to me, cc @myron feel free to create further request for any concerns

monai/losses/dice.py Outdated Show resolved Hide resolved
Signed-off-by: KumoLiu <yunl@nvidia.com>
@wyli
Copy link
Contributor

wyli commented Oct 8, 2023

/build

@wyli wyli enabled auto-merge (squash) October 8, 2023 09:37
@wyli wyli merged commit 7930f85 into Project-MONAI:dev Oct 8, 2023
@KumoLiu KumoLiu deleted the loss-weights branch October 9, 2023 01:55
@myron
Copy link
Collaborator

myron commented Oct 19, 2023

@KumoLiu

Thanks for adding. But why is "self.class_weight" a property and registered as buffer? We recompute it every time on the fly from self.weight. I think "class_weight" can be just an intermediate variable, to avoid confusion. thanks

@KumoLiu
Copy link
Contributor Author

KumoLiu commented Oct 19, 2023

Hi @myron, I think the main reason for registering the class_weight as a buffer is that "class_weight" doesn't require training optimization.

@myron
Copy link
Collaborator

myron commented Oct 22, 2023

@KumoLiu, "self.class_weight" does not require optimization, also it is fully defined by "self.weight" and is recomputed every time in your implementation. As far as I can tell, there is no reason to define "self.class_weight" to be a class property of DiceLoss() at all. We never update that property directly.

Unless there is a reason, it should not be a property of DiceLoss. Please remove it as class property, you can simply have an intermediate local variable called "class_weight" during loss calculation. And in general, if it a simple local variable, do not declare new class properties, it will be very confusing for users. CC @wyli @Nic-Ma

https://github.com/KumoLiu/MONAI/blob/55416891b778afa94d495c538acc16e47a445f02/monai/losses/dice.py#L192-L208

@wyli
Copy link
Contributor

wyli commented Oct 22, 2023

Thanks @myron, this is to be consistent with the pytorch weighted loss interface and with the benefit of saving the weights as part of training stats: https://github.com/pytorch/pytorch/blob/798efab53274ff44d0b5bbd2de59299b529e757c/torch/nn/modules/loss.py#L28

If there's significant computional overheads we can look into refactoring.

@myron
Copy link
Collaborator

myron commented Oct 22, 2023

There is no compute overhead, but it looks like a bad coding practice to me. We have self.weight and self.class_weight now, which both represent the same thing, and furthermore self.class_weight is recomputed every time from self.weight in the forward() pass.

  1. If we are to be consistent with PyT weighted loss, then we already have "self.weight" property. What is the benefit of declaring another one "self.class_weight".
  2. do we really save self.weight as part of training stats somewhere? (we just introduced self.weight) I'm interested to know, where do we save it and what do we save: "self.weight" or "self.class_weight".

Currently, it looks confusing, and a user may attempt to change dice_loss.class_weight=.. property directly, which will not accomplish anything. We should try to simplify our code, unless I'm missing the benefit of this extra property. thanks

@wyli
Copy link
Contributor

wyli commented Oct 23, 2023

perhaps the two variables self.weight and self.class_weight could be merged into one registered buffer, or could we remove self.weight? cc @KumoLiu

(also this discussion has more details about using the usage https://discuss.pytorch.org/t/what-is-the-difference-between-register-buffer-and-register-parameter-of-nn-module/32723)

@KumoLiu
Copy link
Contributor Author

KumoLiu commented Oct 24, 2023

Thanks @wyli and @myron for the suggestions. I have created another PR to address it. I removed self.weight and now we only have one registered buffer. Please check whether it makes sense to you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

DiceLoss add weight per class parameter
3 participants