Skip to content

Commit

Permalink
Fixes typing issues in adversarial loss
Browse files Browse the repository at this point in the history
  • Loading branch information
marksgraham committed Jul 31, 2023
1 parent 105c3b8 commit 700096d
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions monai/losses/adversarial_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
# Depending on the criterion, a different activation layer is used.
self.real_label = 1.0
self.fake_label = 0.0
self.loss_fct : _Loss
if criterion == AdversarialCriterions.BCE.value:
self.activation = get_act_layer("SIGMOID")
self.loss_fct = torch.nn.BCELoss(reduction=reduction)
Expand All @@ -80,7 +81,7 @@ def __init__(
self.criterion = criterion
self.reduction = reduction

def get_target_tensor(self, input: torch.FloatTensor, target_is_real: bool) -> torch.Tensor:
def get_target_tensor(self, input: torch.Tensor, target_is_real: bool) -> torch.Tensor:
"""
Gets the ground truth tensor for the discriminator depending on whether the input is real or fake.
Expand All @@ -95,7 +96,7 @@ def get_target_tensor(self, input: torch.FloatTensor, target_is_real: bool) -> t
label_tensor.requires_grad_(False)
return label_tensor.expand_as(input)

def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor:
def get_zero_tensor(self, input: torch.Tensor) -> torch.Tensor:
"""
Gets a zero tensor.
Expand All @@ -109,7 +110,7 @@ def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor:
return zero_label_tensor.expand_as(input)

def forward(
self, input: torch.FloatTensor | list, target_is_real: bool, for_discriminator: bool
self, input: torch.Tensor | list, target_is_real: bool, for_discriminator: bool
) -> torch.Tensor | list[torch.Tensor]:
"""
Expand Down Expand Up @@ -142,32 +143,32 @@ def forward(
target_.append(self.get_zero_tensor(disc_out))

# Loss calculation
loss = []
loss_list = []
for disc_ind, disc_out in enumerate(input):
if self.activation is not None:
disc_out = self.activation(disc_out)
if self.criterion == AdversarialCriterions.HINGE.value and not target_is_real:
loss_ = self.forward_single(-disc_out, target_[disc_ind])
else:
loss_ = self.forward_single(disc_out, target_[disc_ind])
loss.append(loss_)
loss_list.append(loss_)

if loss is not None:
if loss_list is not None:
if self.reduction == LossReduction.MEAN.value:
loss = torch.mean(torch.stack(loss))
loss = torch.mean(torch.stack(loss_list))
elif self.reduction == LossReduction.SUM.value:
loss = torch.sum(torch.stack(loss))
loss = torch.sum(torch.stack(loss_list))

return loss

def forward_single(self, input: torch.FloatTensor, target: torch.FloatTensor) -> torch.Tensor | None:
def forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
forward : torch.Tensor
if (
self.criterion == AdversarialCriterions.BCE.value
or self.criterion == AdversarialCriterions.LEAST_SQUARE.value
):
return self.loss_fct(input, target)
forward = self.loss_fct(input, target)
elif self.criterion == AdversarialCriterions.HINGE.value:
minval = torch.min(input - 1, self.get_zero_tensor(input))
return -torch.mean(minval)
else:
return None
forward = -torch.mean(minval)
return forward

0 comments on commit 700096d

Please sign in to comment.