Skip to content

Commit

Permalink
Fixes more typing errors
Browse files Browse the repository at this point in the history
  • Loading branch information
marksgraham committed Jul 31, 2023
1 parent 4b1d801 commit cd01b59
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
8 changes: 5 additions & 3 deletions monai/losses/adversarial_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
# Depending on the criterion, a different activation layer is used.
self.real_label = 1.0
self.fake_label = 0.0
self.loss_fct : _Loss
self.loss_fct: _Loss
if criterion == AdversarialCriterions.BCE.value:
self.activation = get_act_layer("SIGMOID")
self.loss_fct = torch.nn.BCELoss(reduction=reduction)
Expand Down Expand Up @@ -153,16 +153,18 @@ def forward(
loss_ = self.forward_single(disc_out, target_[disc_ind])
loss_list.append(loss_)

loss: torch.Tensor | list[torch.Tensor]
if loss_list is not None:
if self.reduction == LossReduction.MEAN.value:
loss = torch.mean(torch.stack(loss_list))
elif self.reduction == LossReduction.SUM.value:
loss = torch.sum(torch.stack(loss_list))

else:
loss = loss_list
return loss

def forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
forward : torch.Tensor
forward: torch.Tensor
if (
self.criterion == AdversarialCriterions.BCE.value
or self.criterion == AdversarialCriterions.LEAST_SQUARE.value
Expand Down
7 changes: 3 additions & 4 deletions monai/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
torchvision, _ = optional_import("torchvision")



class PerceptualLoss(nn.Module):
"""
Perceptual loss using features from pretrained deep neural networks trained. The function supports networks
Expand Down Expand Up @@ -78,7 +77,7 @@ def __init__(
torch.hub.set_dir(cache_dir)

self.spatial_dims = spatial_dims
self.perceptual_function : nn.Module
self.perceptual_function: nn.Module
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
elif "radimagenet_" in network_type:
Expand Down Expand Up @@ -196,7 +195,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
feats_input = normalize_tensor(outs_input)
feats_target = normalize_tensor(outs_target)

results : torch.Tensor = (feats_input - feats_target) ** 2
results: torch.Tensor = (feats_input - feats_target) ** 2
results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)

return results
Expand Down Expand Up @@ -345,7 +344,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
feats_input = normalize_tensor(outs_input)
feats_target = normalize_tensor(outs_target)

results : torch.Tensor = (feats_input - feats_target) ** 2
results: torch.Tensor = (feats_input - feats_target) ** 2
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)

return results
Expand Down

0 comments on commit cd01b59

Please sign in to comment.