Skip to content

Commit

Permalink
Mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
marksgraham committed Aug 8, 2023
1 parent 917e504 commit 3f9a492
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions monai/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,11 @@ def compute_ms_ssim(
f" spatial dimensions, got {dims}."
)

if not isinstance(kernel_size, Sequence):
kernel_size = ensure_tuple_rep(kernel_size, spatial_dims)

if not isinstance(kernel_sigma, Sequence):
kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims)
# check if image have enough size for the number of downsamplings and the size of the kernel
weights_div = max(1, (len(weights) - 1)) ** 2
y_pred_spatial_dims = y_pred.shape[2:]
Expand All @@ -566,12 +571,12 @@ def compute_ms_ssim(
f"{(kernel_size[i] - 1) * weights_div}."
)

weights = torch.tensor(weights, device=y_pred.device, dtype=torch.float)
weights_tensor = torch.tensor(weights, device=y_pred.device, dtype=torch.float)

avg_pool = getattr(F, f"avg_pool{spatial_dims}d")

multiscale_list: list[torch.Tensor] = []
for _ in range(len(weights)):
for _ in range(len(weights_tensor)):
ssim, cs = compute_ssim_and_cs(
y_pred=y_pred,
y=y,
Expand All @@ -592,9 +597,9 @@ def compute_ms_ssim(

ssim = ssim.view(ssim.shape[0], -1).mean(1)
multiscale_list[-1] = torch.relu(ssim)
multiscale_list = torch.stack(multiscale_list)
multiscale_list_tensor = torch.stack(multiscale_list)

ms_ssim_value_full_image = torch.prod(multiscale_list ** weights.view(-1, 1), dim=0)
ms_ssim_value_full_image = torch.prod(multiscale_list_tensor ** weights_tensor.view(-1, 1), dim=0)

ms_ssim_per_batch: torch.Tensor = ms_ssim_value_full_image.view(ms_ssim_value_full_image.shape[0], -1).mean(
1, keepdim=True
Expand Down

0 comments on commit 3f9a492

Please sign in to comment.