Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6676 port generative metrics #6836

Merged
merged 18 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,24 @@ Metrics
-------------------------------------
.. autoclass:: monai.metrics.regression.SSIMMetric

`Multi-scale structural similarity index measure`
-------------------------------------------------
.. autoclass:: MultiScaleSSIMMetric

`Fréchet Inception Distance`
------------------------------
.. autofunction:: compute_frechet_distance

.. autoclass:: FIDMetric
:members:

`Maximum Mean Discrepancy`
------------------------------
.. autofunction:: compute_mmd

.. autoclass:: MMDMetric
:members:

`Cumulative average`
--------------------
.. autoclass:: CumulativeAverage
Expand All @@ -156,6 +174,8 @@ Metrics
.. autoclass:: MetricsReloadedCategorical
:members:



Utilities
---------
.. automodule:: monai.metrics.utils
Expand Down
13 changes: 12 additions & 1 deletion monai/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,26 @@
from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
from .cumulative_average import CumulativeAverage
from .f_beta_score import FBetaScore
from .fid import FIDMetric, compute_frechet_distance
from .froc import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score
from .generalized_dice import GeneralizedDiceScore, compute_generalized_dice
from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
from .loss_metric import LossMetric
from .meandice import DiceHelper, DiceMetric, compute_dice
from .meaniou import MeanIoU, compute_iou, compute_meaniou
from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric
from .mmd import MMDMetric, compute_mmd
from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality
from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric, SSIMMetric
from .regression import (
MAEMetric,
MSEMetric,
MultiScaleSSIMMetric,
PSNRMetric,
RMSEMetric,
SSIMMetric,
compute_ms_ssim,
compute_ssim_and_cs,
)
from .rocauc import ROCAUCMetric, compute_roc_auc
from .surface_dice import SurfaceDiceMetric, compute_surface_dice
from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance
Expand Down
111 changes: 111 additions & 0 deletions monai/metrics/fid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import numpy as np
import torch

from monai.metrics.metric import Metric
from monai.utils import optional_import

scipy, _ = optional_import("scipy")


class FIDMetric(Metric):
    """
    Frechet Inception Distance (FID): a distance between two distributions of feature vectors.

    Based on: Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium."
    https://arxiv.org/abs/1706.08500. Both inputs of this metric should be groups of feature vectors (with format
    (number of images, number of features)) extracted from a pretrained network.

    Originally, it was proposed to use the activations of the pool_3 layer of an Inception v3 pretrained on ImageNet.
    However, other networks pretrained on medical datasets can be used as well (for example, RadImageNet for 2D and
    MedicalNet for 3D images). If the chosen model output is not a scalar, a global spatial average pooling should
    be used.
    """

    def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute the FID score by delegating to `get_fid_score`.

        Args:
            y_pred: feature vectors extracted from a pretrained network run on generated images.
            y: feature vectors extracted from a pretrained network run on images from the real data distribution.

        Returns:
            The value returned by `get_fid_score(y_pred, y)`.
        """
        fid_score = get_fid_score(y_pred, y)
        return fid_score


def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Computes the FID score metric on a batch of feature vectors.

    Args:
        y_pred: feature vectors extracted from a pretrained network run on generated images.
        y: feature vectors extracted from a pretrained network run on images from the real data distribution.

    Returns:
        The Frechet distance computed by `compute_frechet_distance` from the per-feature means and the
        covariance matrices of ``y_pred`` and ``y``.

    Raises:
        ValueError: if ``y`` or ``y_pred`` has more than two dimensions.
    """
    # Validate both inputs before doing any work; a higher-rank `y_pred` would otherwise only surface later
    # as a less helpful error from the covariance computation.
    if y.ndimension() > 2 or y_pred.ndimension() > 2:
        raise ValueError("Inputs should have (number images, number of features) shape.")

    # The statistics are computed in double precision.
    y = y.double()
    y_pred = y_pred.double()

    # Rows are observations (images) and columns are variables (features), hence `rowvar=False`.
    mu_y_pred = torch.mean(y_pred, dim=0)
    sigma_y_pred = _cov(y_pred, rowvar=False)
    mu_y = torch.mean(y, dim=0)
    sigma_y = _cov(y, rowvar=False)

    return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y)


def _cov(input_data: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
"""
Estimate a covariance matrix of the variables.

Args:
input_data: A 1-D or 2-D array containing multiple variables and observations. Each row of `m` represents a variable,
and each column a single observation of all those variables.
rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns.
Otherwise, the relationship is transposed: each column represents a variable, while the rows contain
observations.
"""
if input_data.dim() < 2:
input_data = input_data.view(1, -1)

if not rowvar and input_data.size(0) != 1:
input_data = input_data.t()

factor = 1.0 / (input_data.size(1) - 1)
input_data = input_data - torch.mean(input_data, dim=1, keepdim=True)
return factor * input_data.matmul(input_data.t()).squeeze()


def _sqrtm(input_data: torch.Tensor) -> torch.Tensor:
"""Compute the square root of a matrix."""
scipy_res, _ = scipy.linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float_), disp=False)
return torch.from_numpy(scipy_res)


def compute_frechet_distance(
    mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6
) -> torch.Tensor:
    """
    The Frechet distance between multivariate normal distributions.

    The value computed is ``|mu_x - mu_y|^2 + Tr(sigma_x) + Tr(sigma_y) - 2 * Tr(sqrt(sigma_x @ sigma_y))``.

    Args:
        mu_x: mean vector of the first distribution.
        sigma_x: covariance matrix of the first distribution.
        mu_y: mean vector of the second distribution.
        sigma_y: covariance matrix of the second distribution.
        epsilon: value added to the diagonal of both covariance matrices when the matrix square root of their
            product contains non-finite entries.

    Returns:
        The Frechet distance as a tensor.

    Raises:
        ValueError: if the matrix square root is complex and the imaginary part of its diagonal is not
            within 1e-3 of zero.
    """
    mean_gap = mu_x - mu_y

    covmean = _sqrtm(torch.mm(sigma_x, sigma_y))

    # The product might be almost singular: retry once with a small offset on both diagonals.
    all_finite = torch.isfinite(covmean).all()
    if not all_finite:
        print(f"FID calculation produces singular product; adding {epsilon} to diagonal of covariance estimates")
        offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon
        covmean = _sqrtm(torch.mm(sigma_x + offset, sigma_y + offset))

    # Numerical error might give a slight imaginary component; tolerate it only if it is negligible.
    if covmean.is_complex():
        diagonal_imag = torch.diagonal(covmean).imag
        zero = torch.tensor(0, dtype=torch.double)
        if not torch.allclose(diagonal_imag, zero, atol=1e-3):
            raise ValueError(f"Imaginary component {torch.max(torch.abs(covmean.imag))} too high.")
        covmean = covmean.real

    tr_covmean = torch.trace(covmean)
    return mean_gap.dot(mean_gap) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean
91 changes: 91 additions & 0 deletions monai/metrics/mmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable

import torch

from monai.metrics.metric import Metric


class MMDMetric(Metric):
    """
    Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two
    distributions. A smaller value indicates a closer match between the two distributions.
    NOTE(review): the MMD itself is non-negative, but the unbiased estimate computed by `compute_mmd` can come
    out negative for finite samples - confirm this is acceptable for callers.

    Gretton, A., et al., 2012. A kernel two-sample test. The Journal of Machine Learning Research, 13(1), pp.723-773.

    Args:
        y_mapping: Callable to transform the y tensors before computing the metric. It is usually a Gaussian or Laplace
            filter, but it can be any function that takes a tensor as input and returns a tensor as output such as a
            feature extractor or an Identity function, e.g. `y_mapping = lambda x: x.square()`.
    """

    def __init__(self, y_mapping: Callable | None = None) -> None:
        super().__init__()
        # Optional transform; it is handed to `compute_mmd` together with the two samples on every call.
        self.y_mapping = y_mapping

    def __call__(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Compute the MMD between the two samples by delegating to `compute_mmd`.

        Args:
            y: first sample (e.g., the reference image).
            y_pred: second sample (e.g., the reconstructed image).
        """
        mapping = self.y_mapping
        return compute_mmd(y, y_pred, mapping)


def compute_mmd(y: torch.Tensor, y_pred: torch.Tensor, y_mapping: Callable | None) -> torch.Tensor:
"""
Args:
y: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D.
y_pred: second sample (e.g., the reconstructed image). It has similar shape as y.
y_mapping: Callable to transform the y tensors before computing the metric.
"""
if y_pred.shape[0] == 1 or y.shape[0] == 1:
raise ValueError("MMD metric requires at least two samples in y and y_pred.")

if y_mapping is not None:
y = y_mapping(y)
y_pred = y_mapping(y_pred)

if y_pred.shape != y.shape:
raise ValueError(
"y_pred and y shapes dont match after being processed "
f"by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}"
)

for d in range(len(y.shape) - 1, 1, -1):
y = y.squeeze(dim=d)
y_pred = y_pred.squeeze(dim=d)

y = y.view(y.shape[0], -1)
y_pred = y_pred.view(y_pred.shape[0], -1)

y_y = torch.mm(y, y.t())
y_pred_y_pred = torch.mm(y_pred, y_pred.t())
y_pred_y = torch.mm(y_pred, y.t())

m = y.shape[0]
n = y_pred.shape[0]

# Ref. 1 Eq. 3 (found under Lemma 6)
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
# term 1
c1 = 1 / (m * (m - 1))
a = torch.sum(y_y - torch.diag(torch.diagonal(y_y)))

# term 2
c2 = 1 / (n * (n - 1))
b = torch.sum(y_pred_y_pred - torch.diag(torch.diagonal(y_pred_y_pred)))

# term 3
c3 = 2 / (m * n)
c = torch.sum(y_pred_y)

mmd = c1 * a + c2 * b - c3 * c
return mmd
Loading