diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index a7f90b754d96..ef1250314811 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -355,6 +355,7 @@ Complete list of metrics Entropy KLDivergence JSDivergence + MaximumMeanDiscrepancy AveragePrecision CohenKappa GpuInfo diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 2cc55aace661..e4f4e24337c5 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -17,6 +17,7 @@ from ignite.metrics.js_divergence import JSDivergence from ignite.metrics.kl_divergence import KLDivergence from ignite.metrics.loss import Loss +from ignite.metrics.maximum_mean_discrepancy import MaximumMeanDiscrepancy from ignite.metrics.mean_absolute_error import MeanAbsoluteError from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance from ignite.metrics.mean_squared_error import MeanSquaredError @@ -61,6 +62,7 @@ "JaccardIndex", "JSDivergence", "KLDivergence", + "MaximumMeanDiscrepancy", "MultiLabelConfusionMatrix", "MutualInformation", "Precision", diff --git a/ignite/metrics/maximum_mean_discrepancy.py b/ignite/metrics/maximum_mean_discrepancy.py new file mode 100644 index 000000000000..d92dd5448ce1 --- /dev/null +++ b/ignite/metrics/maximum_mean_discrepancy.py @@ -0,0 +1,138 @@ +from typing import Callable, Sequence + +import torch + +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + +__all__ = ["MaximumMeanDiscrepancy"] + + +class MaximumMeanDiscrepancy(Metric): + r"""Calculates the mean of `maximum mean discrepancy (MMD) + `_. + + .. 
math:: + \begin{align*} + \text{MMD}^2 (P,Q) &= \underset{\| f \| \leq 1}{\text{sup}} | \mathbb{E}_{X\sim P}[f(X)] + - \mathbb{E}_{Y\sim Q}[f(Y)] |^2 \\ + &\approx \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{x}_i,\mathbf{x}_j) + -\frac{2}{B^2}\sum_{i=1}^B \sum_{j=1}^B k(\mathbf{x}_i,\mathbf{y}_j) + + \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{y}_i,\mathbf{y}_j) + \end{align*} + + where :math:`B` is the batch size, and :math:`\mathbf{x}_i` and :math:`\mathbf{y}_j` are + feature vectors sampled from :math:`P` and :math:`Q`, respectively. + :math:`k(\mathbf{x},\mathbf{y})=\exp(-\| \mathbf{x}-\mathbf{y} \|^2/ 2\sigma^2)` is the Gaussian RBF kernel. + + This metric computes the MMD for each batch and takes the average. + + More details can be found in `Gretton et al. 2012`__. + + __ https://jmlr.csail.mit.edu/papers/v13/gretton12a.html + + - ``update`` must receive output of the form ``(x, y)``. + - ``x`` and ``y`` are expected to be in the same shape :math:`(B, \ldots)`. + + Args: + var: the bandwidth :math:`\sigma^2` of the kernel. Default: 1.0 + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. + By default, this metric requires the output as ``(x, y)``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in the format of + ``(x, y)``. 
If not, ``output_transform`` can be added + to the metric to transform the output into the form expected by the metric. + + For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. + + .. include:: defaults.rst + :start-after: :orphan: + + .. testcode:: + + metric = MaximumMeanDiscrepancy() + metric.attach(default_evaluator, "mmd") + x = torch.tensor([[-0.80324818, -0.95768364, -0.03807209], + [-0.11059691, -0.38230813, -0.4111988], + [-0.8864329, -0.02890403, -0.60119252], + [-0.68732452, -0.12854739, -0.72095073], + [-0.62604613, -0.52368328, -0.24112842]]) + y = torch.tensor([[0.0686768, 0.80502737, 0.53321717], + [0.83849465, 0.59099726, 0.76385441], + [0.68688272, 0.56833803, 0.98100778], + [0.55267761, 0.13084654, 0.45382906], + [0.0754253, 0.70317304, 0.4756805]]) + state = default_evaluator.run([[x, y]]) + print(state.metrics["mmd"]) + + .. testoutput:: + + 1.0726975202560425 + """ + + _state_dict_all_req_keys = ("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches") + + def __init__( + self, var: float = 1.0, output_transform: Callable = lambda x: x, device: torch.device = torch.device("cpu") + ): + self.var = var + super().__init__(output_transform, device) + + @reinit__is_reduced + def reset(self) -> None: + self._xx_sum = torch.tensor(0.0, device=self._device) + self._yy_sum = torch.tensor(0.0, device=self._device) + self._xy_sum = torch.tensor(0.0, device=self._device) + self._num_batches = 0 + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + x, y = output[0].detach(), output[1].detach() + if x.shape != y.shape: + raise ValueError(f"x and y must be in the same shape, got {x.shape} != {y.shape}.") + + if x.ndim >= 3: + x = x.flatten(start_dim=1) + y = y.flatten(start_dim=1) + elif x.ndim == 1: + raise ValueError(f"x must be in the shape of (B, ...), got {x.shape}.") + + xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) + rx =
xx.diag().unsqueeze(0).expand_as(xx) + ry = yy.diag().unsqueeze(0).expand_as(yy) + + dxx = rx.t() + rx - 2.0 * xx + dyy = ry.t() + ry - 2.0 * yy + dxy = rx.t() + ry - 2.0 * zz + + v = self.var + XX = torch.exp(-0.5 * dxx / v) + YY = torch.exp(-0.5 * dyy / v) + XY = torch.exp(-0.5 * dxy / v) + + # unbiased + n = x.shape[0] + XX = (XX.sum() - n) / (n * (n - 1)) + YY = (YY.sum() - n) / (n * (n - 1)) + XY = XY.sum() / (n * n) + + self._xx_sum += XX.to(self._device) + self._yy_sum += YY.to(self._device) + self._xy_sum += XY.to(self._device) + + self._num_batches += 1 + + @sync_all_reduce("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches") + def compute(self) -> float: + if self._num_batches == 0: + raise NotComputableError("MaximumMeanDiscrepacy must have at least one batch before it can be computed.") + mmd2 = (self._xx_sum + self._yy_sum - 2.0 * self._xy_sum).clamp(min=0.0) / self._num_batches + return mmd2.sqrt().item() diff --git a/tests/ignite/metrics/test_maximum_mean_discrepancy.py b/tests/ignite/metrics/test_maximum_mean_discrepancy.py new file mode 100644 index 000000000000..8cfc5f55567d --- /dev/null +++ b/tests/ignite/metrics/test_maximum_mean_discrepancy.py @@ -0,0 +1,176 @@ +from typing import Tuple + +import numpy as np +import pytest +import torch +from torch import Tensor + +import ignite.distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics import MaximumMeanDiscrepancy + + +def np_mmd2(x: np.ndarray, y: np.ndarray, var: float = 1.0): + n = x.shape[0] + x = x.reshape(n, -1) + y = y.reshape(n, -1) + + a = np.arange(n) + ii, jj = np.meshgrid(a, a, indexing="ij") + XX = np.exp(-np.square(x[ii] - x[jj]).sum(axis=2) / (var * 2)) + XX = (np.sum(XX) - n) / (n * (n - 1)) + + XY = np.exp(-np.square(x[ii] - y[jj]).sum(axis=2) / (var * 2)) + XY = np.sum(XY) / (n * n) + + YY = np.exp(-np.square(y[ii] - y[jj]).sum(axis=2) / (var * 2)) + YY = (np.sum(YY) - n) / (n * (n - 1)) + + mmd2 = np.clip(XX + 
YY - XY * 2, 0.0, None) + return mmd2 + + +def test_zero_sample(): + mmd = MaximumMeanDiscrepancy() + with pytest.raises( + NotComputableError, match=r"MaximumMeanDiscrepacy must have at least one batch before it can be computed" + ): + mmd.compute() + + +def test_shape_mismatch(): + mmd = MaximumMeanDiscrepancy() + x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float) + y = torch.tensor([[-2.0, 1.0]], dtype=torch.float) + with pytest.raises(ValueError, match=r"x and y must be in the same shape, got"): + mmd.update((x, y)) + + +def test_invalid_shape(): + mmd = MaximumMeanDiscrepancy() + x = torch.tensor([2.0, 3.0], dtype=torch.float) + y = torch.tensor([4.0, 5.0], dtype=torch.float) + with pytest.raises(ValueError, match=r"x must be in the shape of \(B, ...\), got"): + mmd.update((x, y)) + + +@pytest.fixture(params=list(range(4))) +def test_case(request): + return [ + (torch.randn((100, 10)), torch.rand((100, 10)), 10 ** np.random.uniform(-1.0, 0.0), 1), + (torch.rand((100, 500)), torch.randn((100, 500)), 10 ** np.random.uniform(-1.0, 0.0), 1), + # updated batches + (torch.normal(0.0, 5.0, size=(100, 10)), torch.rand((100, 10)), 10 ** np.random.uniform(-1.0, 0.0), 16), + (torch.normal(5.0, 3.0, size=(100, 200)), torch.rand((100, 200)), 10 ** np.random.uniform(-1.0, 0.0), 16), + # image segmentation + (torch.randn((100, 5, 32, 32)), torch.rand((100, 5, 32, 32)), 10 ** np.random.uniform(-1.0, 0.0), 32), + (torch.rand((100, 5, 224, 224)), torch.randn((100, 5, 224, 224)), 10 ** np.random.uniform(-1.0, 0.0), 32), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int]): + x, y, var, batch_size = test_case + + mmd = MaximumMeanDiscrepancy(var=var) + mmd.reset() + + if batch_size > 1: + np_mmd2_sum = 0.0 + n_iters = y.shape[0] // batch_size + 1 + for i in range(n_iters): + idx = i * batch_size + x_batch, y_batch = x[idx : idx + batch_size], y[idx : idx + batch_size] + 
mmd.update((x_batch, y_batch)) + + np_mmd2_sum += np_mmd2(x_batch.cpu().numpy(), y_batch.cpu().numpy(), var) + + np_res = np.sqrt(np_mmd2_sum / n_iters) + else: + mmd.update((x, y)) + np_res = np.sqrt(np_mmd2(x.cpu().numpy(), y.cpu().numpy(), var)) + + res = mmd.compute() + + assert isinstance(res, float) + assert pytest.approx(np_res, abs=1e-4) == res + + +def test_accumulator_detached(): + mmd = MaximumMeanDiscrepancy() + + x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float) + y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float) + mmd.update((x, y)) + + assert not any(acc.requires_grad for acc in (mmd._xx_sum, mmd._yy_sum, mmd._xy_sum)) + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_integration(self): + tol = 1e-4 + n_iters = 100 + batch_size = 10 + n_dims = 100 + + rank = idist.get_rank() + torch.manual_seed(12 + rank) + + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + for metric_device in metric_devices: + y = torch.randn((n_iters * batch_size, n_dims)).float().to(device) + x = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_dims)).float().to(device) + + def data_loader(i): + return x[i * batch_size : (i + 1) * batch_size], y[i * batch_size : (i + 1) * batch_size] + + engine = Engine(lambda e, i: data_loader(i)) + + m = MaximumMeanDiscrepancy(device=metric_device) + m.attach(engine, "mmd") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) + + x = idist.all_gather(x) + y = idist.all_gather(y) + + assert "mmd" in engine.state.metrics + res = engine.state.metrics["mmd"] + + # compute numpy mmd + true_res = 0.0 + for i in range(n_iters): + x_batch, y_batch = data_loader(i) + x_np = x_batch.cpu().numpy() + y_np = y_batch.cpu().numpy() + true_res += np_mmd2(x_np, y_np) + + true_res = np.sqrt(true_res / n_iters) + assert pytest.approx(true_res, abs=tol) == res + + def test_accumulator_device(self): + device = 
idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + for metric_device in metric_devices: + mmd = MaximumMeanDiscrepancy(device=metric_device) + + devices = (mmd._device, mmd._xx_sum.device, mmd._yy_sum.device, mmd._xy_sum.device) + for dev in devices: + assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" + + x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float() + y = torch.ones(2, 2).float() + mmd.update((x, y)) + + devices = (mmd._device, mmd._xx_sum.device, mmd._yy_sum.device, mmd._xy_sum.device) + for dev in devices: + assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"