diff --git a/monai/losses/adversarial_loss.py b/monai/losses/adversarial_loss.py index 62cff46200..2165fc8daa 100644 --- a/monai/losses/adversarial_loss.py +++ b/monai/losses/adversarial_loss.py @@ -64,6 +64,7 @@ def __init__( # Depending on the criterion, a different activation layer is used. self.real_label = 1.0 self.fake_label = 0.0 + self.loss_fct : _Loss if criterion == AdversarialCriterions.BCE.value: self.activation = get_act_layer("SIGMOID") self.loss_fct = torch.nn.BCELoss(reduction=reduction) @@ -80,7 +81,7 @@ def __init__( self.criterion = criterion self.reduction = reduction - def get_target_tensor(self, input: torch.FloatTensor, target_is_real: bool) -> torch.Tensor: + def get_target_tensor(self, input: torch.Tensor, target_is_real: bool) -> torch.Tensor: """ Gets the ground truth tensor for the discriminator depending on whether the input is real or fake. @@ -95,7 +96,7 @@ def get_target_tensor(self, input: torch.FloatTensor, target_is_real: bool) -> t label_tensor.requires_grad_(False) return label_tensor.expand_as(input) - def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor: + def get_zero_tensor(self, input: torch.Tensor) -> torch.Tensor: """ Gets a zero tensor. @@ -109,7 +110,7 @@ def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor: return zero_label_tensor.expand_as(input) def forward( - self, input: torch.FloatTensor | list, target_is_real: bool, for_discriminator: bool + self, input: torch.Tensor | list, target_is_real: bool, for_discriminator: bool ) -> torch.Tensor | list[torch.Tensor]: """ @@ -142,7 +143,7 @@ def forward( target_.append(self.get_zero_tensor(disc_out)) # Loss calculation - loss = [] + loss_list = [] for disc_ind, disc_out in enumerate(input): if self.activation is not None: disc_out = self.activation(disc_out) @@ -150,24 +151,24 @@ def forward( loss_ = self.forward_single(-disc_out, target_[disc_ind]) else: loss_ = self.forward_single(disc_out, target_[disc_ind]) - loss.append(loss_) + loss_list.append(loss_) - if loss is not None: + if loss_list is not None: if self.reduction == LossReduction.MEAN.value: - loss = torch.mean(torch.stack(loss)) + loss = torch.mean(torch.stack(loss_list)) elif self.reduction == LossReduction.SUM.value: - loss = torch.sum(torch.stack(loss)) + loss = torch.sum(torch.stack(loss_list)) return loss - def forward_single(self, input: torch.FloatTensor, target: torch.FloatTensor) -> torch.Tensor | None: + def forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + forward : torch.Tensor if ( self.criterion == AdversarialCriterions.BCE.value or self.criterion == AdversarialCriterions.LEAST_SQUARE.value ): - return self.loss_fct(input, target) + forward = self.loss_fct(input, target) elif self.criterion == AdversarialCriterions.HINGE.value: minval = torch.min(input - 1, self.get_zero_tensor(input)) - return -torch.mean(minval) - else: - return None + forward = -torch.mean(minval) + return forward