Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Jan 29, 2021
1 parent c435327 commit 46fc7c0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -
return alpha


class DiceCELoss:
class DiceCELoss(_Loss):
"""
Compute both Dice loss and Cross Entropy Loss, and return the sum of these two losses.
Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dice_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@
class TestDiceCELoss(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_result(self, input_param, input_data, expected_val):
result = DiceCELoss(**input_param).forward(**input_data)
result = DiceCELoss(**input_param)(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

def test_ill_shape(self):
loss = DiceCELoss()
with self.assertRaisesRegex(ValueError, ""):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))


if __name__ == "__main__":
Expand Down

0 comments on commit 46fc7c0

Please sign in to comment.