From cd01b5955f4daec2f912d9d4d7a4859768186d6e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 31 Jul 2023 15:26:06 +0100 Subject: [PATCH] Fixes more typing errors --- monai/losses/adversarial_loss.py | 8 +++++--- monai/losses/perceptual.py | 7 +++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/monai/losses/adversarial_loss.py b/monai/losses/adversarial_loss.py index 2165fc8daa..9005b7f030 100644 --- a/monai/losses/adversarial_loss.py +++ b/monai/losses/adversarial_loss.py @@ -64,7 +64,7 @@ def __init__( # Depending on the criterion, a different activation layer is used. self.real_label = 1.0 self.fake_label = 0.0 - self.loss_fct : _Loss + self.loss_fct: _Loss if criterion == AdversarialCriterions.BCE.value: self.activation = get_act_layer("SIGMOID") self.loss_fct = torch.nn.BCELoss(reduction=reduction) @@ -153,16 +153,18 @@ def forward( loss_ = self.forward_single(disc_out, target_[disc_ind]) loss_list.append(loss_) + loss: torch.Tensor | list[torch.Tensor] if loss_list is not None: if self.reduction == LossReduction.MEAN.value: loss = torch.mean(torch.stack(loss_list)) elif self.reduction == LossReduction.SUM.value: loss = torch.sum(torch.stack(loss_list)) - + else: + loss = loss_list return loss def forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - forward : torch.Tensor + forward: torch.Tensor if ( self.criterion == AdversarialCriterions.BCE.value or self.criterion == AdversarialCriterions.LEAST_SQUARE.value diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 0b574395a3..e9a801c532 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -20,7 +20,6 @@ torchvision, _ = optional_import("torchvision") - class PerceptualLoss(nn.Module): """ Perceptual loss using features from pretrained deep neural networks trained. The function supports networks @@ -78,7 +77,7 @@ def __init__( torch.hub.set_dir(cache_dir) self.spatial_dims = spatial_dims - self.perceptual_function : nn.Module + self.perceptual_function: nn.Module if spatial_dims == 3 and is_fake_3d is False: self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False) elif "radimagenet_" in network_type: @@ -196,7 +195,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: feats_input = normalize_tensor(outs_input) feats_target = normalize_tensor(outs_target) - results : torch.Tensor = (feats_input - feats_target) ** 2 + results: torch.Tensor = (feats_input - feats_target) ** 2 results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True) return results @@ -345,7 +344,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: feats_input = normalize_tensor(outs_input) feats_target = normalize_tensor(outs_target) - results : torch.Tensor = (feats_input - feats_target) ** 2 + results: torch.Tensor = (feats_input - feats_target) ** 2 results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) return results