Skip to content

Commit

Permalink
Add channelwise flag to perceptual loss
Browse files Browse the repository at this point in the history
  • Loading branch information
SomeUserName1 committed Mar 25, 2024
1 parent 0250284 commit dd84afc
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 9 deletions.
40 changes: 33 additions & 7 deletions monai/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
pretrained: bool = True,
pretrained_path: str | None = None,
pretrained_state_dict_key: str | None = None,
channelwise: bool = False,
):
super().__init__()

Expand Down Expand Up @@ -102,15 +103,18 @@ def __init__(
self.spatial_dims = spatial_dims
self.perceptual_function: nn.Module
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False,
channelwise=channelwise)
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False,
channelwise=channelwise)
elif network_type == "resnet50":
self.perceptual_function = TorchvisionModelPerceptualSimilarity(
net=network_type,
pretrained=pretrained,
pretrained_path=pretrained_path,
pretrained_state_dict_key=pretrained_state_dict_key,
channelwise=channelwise,
)
else:
self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)
Expand Down Expand Up @@ -185,14 +189,21 @@ class MedicalNetPerceptualSimilarity(nn.Module):
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
verbose: if false, mute messages from torch Hub load function.
channelwise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
Defaults to ``False``.
"""

def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
def __init__(self,
net: str = "medicalnet_resnet10_23datasets",
verbose: bool = False,
channelwise: bool = False) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose)
self.eval()

self.channelwise = channelwise

for param in self.parameters():
param.requires_grad = False

Expand All @@ -206,6 +217,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Args:
input: 3D input tensor with shape BCDHW.
target: 3D target tensor with shape BCDHW.
"""
input = medicalnet_intensity_normalisation(input)
target = medicalnet_intensity_normalisation(target)
Expand All @@ -227,7 +239,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
feats_target = normalize_tensor(outs_target)

results: torch.Tensor = (feats_input - feats_target) ** 2
results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)

if self.channelwise:
results = results.sum(dim=1, keepdim=True)
results = spatial_average_3d(results, keepdim=True)

return results

Expand Down Expand Up @@ -260,11 +275,13 @@ class RadImageNetPerceptualSimilarity(nn.Module):
verbose: if false, mute messages from torch Hub load function.
"""

def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, channelwise: bool = False) -> None:
super().__init__()
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose)
self.eval()

self.channelwise = channelwise

for param in self.parameters():
param.requires_grad = False

Expand Down Expand Up @@ -297,7 +314,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
feats_target = normalize_tensor(outs_target)

results: torch.Tensor = (feats_input - feats_target) ** 2
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)

if self.channelwise:
results = results.sum(dim=1, keepdim=True)
results = spatial_average(results, keepdim=True)

return results

Expand All @@ -324,6 +344,7 @@ def __init__(
pretrained: bool = True,
pretrained_path: str | None = None,
pretrained_state_dict_key: str | None = None,
channelwise: bool = False,
) -> None:
super().__init__()
supported_networks = ["resnet50"]
Expand All @@ -347,6 +368,8 @@ def __init__(
self.model = torchvision.models.feature_extraction.create_feature_extractor(network, [self.final_layer])
self.eval()

self.channelwise = channelwise

for param in self.parameters():
param.requires_grad = False

Expand Down Expand Up @@ -376,7 +399,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
feats_target = normalize_tensor(outs_target)

results: torch.Tensor = (feats_input - feats_target) ** 2
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)

if self.channelwise:
results = results.sum(dim=1, keepdim=True)
results = spatial_average(results, keepdim=True)

return results

Expand Down
29 changes: 27 additions & 2 deletions tests/test_perceptual_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,18 @@
],
[{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 1, 64, 64), (2, 1, 64, 64)],
[{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 3, 64, 64), (2, 3, 64, 64)],
[{"spatial_dims": 2, "network_type": "radimagenet_resnet50", "channelwise": True}, (2, 3, 64, 64), (2, 3, 64, 64)],
[
{"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1},
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1,
'channelwise': True},
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False},
(2, 1, 64, 64, 64),
Expand All @@ -45,6 +52,11 @@
(2, 6, 64, 64, 64),
(2, 6, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False, "channelwise": True},
(2, 6, 64, 64, 64),
(2, 6, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False},
(2, 1, 64, 64, 64),
Expand All @@ -60,6 +72,11 @@
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "resnet50", "pretrained": True, "fake_3d_ratio": 0.2, "channelwise": True},
(2, 3, 64, 64, 64),
(2, 3, 64, 64, 64),
],
]


Expand All @@ -73,15 +90,23 @@ def test_shape(self, input_param, input_shape, target_shape):
with skip_if_downloading_fails():
loss = PerceptualLoss(**input_param)
result = loss(torch.randn(input_shape), torch.randn(target_shape))
self.assertEqual(result.shape, torch.Size([]))

if 'channelwise' in input_param.keys() and input_param['channelwise']:
self.assertEqual(result.shape, torch.Size([input_shape[1]]))
else:
self.assertEqual(result.shape, torch.Size([]))

@parameterized.expand(TEST_CASES)
def test_identical_input(self, input_param, input_shape, target_shape):
    """Check that the perceptual loss of a tensor with itself is zero.

    Args:
        input_param: keyword arguments used to construct ``PerceptualLoss``.
        input_shape: shape of the random tensor used as both input and target.
        target_shape: unused here; present because the test cases are shared
            with the other parameterized tests.
    """
    with skip_if_downloading_fails():
        loss = PerceptualLoss(**input_param)
    tensor = torch.randn(input_shape)
    result = loss(tensor, tensor)

    # With ``channelwise`` enabled one zero per input channel is expected;
    # otherwise a single zero.
    n_expected = input_shape[1] if input_param.get("channelwise", False) else 1
    expected = torch.Tensor([0.0] * n_expected)

    # ``assertEqual`` evaluates ``not (result == expected)``; for a comparison
    # with more than one element that raises "Boolean value of Tensor with more
    # than one value is ambiguous". Reduce the element-wise comparison first.
    self.assertTrue(
        bool(torch.all(result == expected)),
        msg=f"expected zero loss for identical inputs, got {result}",
    )

def test_different_shape(self):
with skip_if_downloading_fails():
Expand Down

0 comments on commit dd84afc

Please sign in to comment.