Skip to content

Commit abc58b3

Browse files
JHancoxbhashemian
authored andcommitted
Add support for <=PyTorch1.9.1 (#5467)
Fixes #5469 ### Description For CrossEntropy, PyTorch v1.10+ allows target values to be either class probabilities or indices, whereas v1.9- only allows for class indices, so class indices are now used. This should provide a better solution for #5393 and remove the need for #5401 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). Signed-off-by: JHancox <48477639+JHancox@users.noreply.github.com> Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> Co-authored-by: Behrooz <3968947+drbeh@users.noreply.github.com>
1 parent 798251b commit abc58b3

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

monai/apps/pathology/losses/hovernet_loss.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def forward(self, prediction: Dict[str, torch.Tensor], target: Dict[str, torch.T
121121
dice_loss_np = (
122122
self.dice(prediction[HoVerNetBranch.NP.value], target[HoVerNetBranch.NP.value]) * self.lambda_np_dice
123123
)
124-
ce_loss_np = self.ce(prediction[HoVerNetBranch.NP.value], target[HoVerNetBranch.NP.value]) * self.lambda_np_ce
124+
# convert to target class indices
125+
argmax_target = target[HoVerNetBranch.NP.value].argmax(dim=1)
126+
ce_loss_np = self.ce(prediction[HoVerNetBranch.NP.value], argmax_target) * self.lambda_np_ce
125127
loss_np = dice_loss_np + ce_loss_np
126128

127129
# Compute the HV branch loss
@@ -146,9 +148,9 @@ def forward(self, prediction: Dict[str, torch.Tensor], target: Dict[str, torch.T
146148
dice_loss_nc = (
147149
self.dice(prediction[HoVerNetBranch.NC.value], target[HoVerNetBranch.NC.value]) * self.lambda_nc_dice
148150
)
149-
ce_loss_nc = (
150-
self.ce(prediction[HoVerNetBranch.NC.value], target[HoVerNetBranch.NC.value]) * self.lambda_nc_ce
151-
)
151+
# Convert to target class indices
152+
argmax_target = target[HoVerNetBranch.NC.value].argmax(dim=1)
153+
ce_loss_nc = self.ce(prediction[HoVerNetBranch.NC.value], argmax_target) * self.lambda_nc_ce
152154
loss_nc = dice_loss_nc + ce_loss_nc
153155

154156
# Sum the losses from each branch

tests/test_hovernet_loss.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from monai.transforms import GaussianSmooth, Rotate
2222
from monai.transforms.intensity.array import ComputeHoVerMaps
2323
from monai.utils.enums import HoVerNetBranch
24-
from tests.utils import SkipIfBeforePyTorchVersion
2524

2625
device = "cuda" if torch.cuda.is_available() else "cpu"
2726

@@ -169,7 +168,6 @@ def test_shape_generator(num_classes=1, num_objects=3, batch_size=1, height=5, w
169168
]
170169

171170

172-
@SkipIfBeforePyTorchVersion((1, 10))
173171
class TestHoverNetLoss(unittest.TestCase):
174172
@parameterized.expand(CASES)
175173
def test_shape(self, input_param, expected_loss):

0 commit comments

Comments
 (0)