Skip to content

Commit

Permalink
Refactor self.weight in DiceLoss and FocalLoss (#7158)
Browse files Browse the repository at this point in the history
Fixes #7065

### Description
Remove `self.weight` in `DiceLoss` and `FocalLoss`

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <yunl@nvidia.com>
  • Loading branch information
KumoLiu authored Oct 24, 2023
1 parent e7feedf commit cc20c9b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
14 changes: 7 additions & 7 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ def __init__(
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.weight = weight
self.register_buffer("class_weight", torch.ones(1))
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -189,13 +190,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)

if self.weight is not None and target.shape[1] != 1:
num_of_classes = target.shape[1]
if self.class_weight is not None and num_of_classes != 1:
# make sure the lengths of weights are equal to the number of classes
num_of_classes = target.shape[1]
if isinstance(self.weight, (float, int)):
self.class_weight = torch.as_tensor([self.weight] * num_of_classes)
if self.class_weight.ndim == 0:
self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
else:
self.class_weight = torch.as_tensor(self.weight)
if self.class_weight.shape[0] != num_of_classes:
raise ValueError(
"""the length of the `weight` sequence should be the same as the number of classes.
Expand Down
13 changes: 7 additions & 6 deletions monai/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def __init__(
self.alpha = alpha
self.weight = weight
self.use_softmax = use_softmax
self.register_buffer("class_weight", torch.ones(1))
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -162,13 +164,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
else:
loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)

if self.weight is not None:
num_of_classes = target.shape[1]
if self.class_weight is not None and num_of_classes != 1:
# make sure the lengths of weights are equal to the number of classes
num_of_classes = target.shape[1]
if isinstance(self.weight, (float, int)):
self.class_weight = torch.as_tensor([self.weight] * num_of_classes)
if self.class_weight.ndim == 0:
self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
else:
self.class_weight = torch.as_tensor(self.weight)
if self.class_weight.shape[0] != num_of_classes:
raise ValueError(
"""the length of the `weight` sequence should be the same as the number of classes.
Expand Down

0 comments on commit cc20c9b

Please sign in to comment.