From 9ace0b48e9fa636f4f3f00d450923c5afe748494 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Wed, 1 May 2024 23:12:45 +0900 Subject: [PATCH 01/10] add MaximumMeanDiscrepancy metric --- docs/source/metrics.rst | 1 + ignite/metrics/__init__.py | 2 + ignite/metrics/maximum_mean_discrepancy.py | 131 +++++++++++++ .../metrics/test_maximum_mean_discrepancy.py | 175 ++++++++++++++++++ 4 files changed, 309 insertions(+) create mode 100644 ignite/metrics/maximum_mean_discrepancy.py create mode 100644 tests/ignite/metrics/test_maximum_mean_discrepancy.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index a7f90b754d96..ef1250314811 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -355,6 +355,7 @@ Complete list of metrics Entropy KLDivergence JSDivergence + MaximumMeanDiscrepancy AveragePrecision CohenKappa GpuInfo diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 2cc55aace661..e4f4e24337c5 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -17,6 +17,7 @@ from ignite.metrics.js_divergence import JSDivergence from ignite.metrics.kl_divergence import KLDivergence from ignite.metrics.loss import Loss +from ignite.metrics.maximum_mean_discrepancy import MaximumMeanDiscrepancy from ignite.metrics.mean_absolute_error import MeanAbsoluteError from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance from ignite.metrics.mean_squared_error import MeanSquaredError @@ -61,6 +62,7 @@ "JaccardIndex", "JSDivergence", "KLDivergence", + "MaximumMeanDiscrepancy", "MultiLabelConfusionMatrix", "MutualInformation", "Precision", diff --git a/ignite/metrics/maximum_mean_discrepancy.py b/ignite/metrics/maximum_mean_discrepancy.py new file mode 100644 index 000000000000..e5edc0f801a7 --- /dev/null +++ b/ignite/metrics/maximum_mean_discrepancy.py @@ -0,0 +1,131 @@ +from typing import Callable, Sequence + +import torch + +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + +__all__ = ["MaximumMeanDiscrepancy"] + + +class MaximumMeanDiscrepancy(Metric): + r"""Calculates the mean of `maximum mean discrepancy (MMD) + `_. + + .. math:: + \begin{align*} + \text{MMD}^2 (P,Q) &= \underset{\| f \| \leq 1}{\text{sup}} | \mathbb{E}_{X\sim P}[f(X)] + - \mathbb{E}_{Y\sim Q}[f(Y)] |^2 \\ + &\approx \frac{1}{B(B-1)} \sum_{i} \sum_{j\neq i}k(\mathbf{x}_i,\mathbf{x}_j) + -\frac{2}{B^2}\sum_{i} \sum_{j} k(\mathbf{x}_i,\mathbf{y}_j) + + \frac{1}{B(B-1)} \sum_{i} \sum_{j\neq i}k(\mathbf{y}_i,\mathbf{y}_j) + \end{align*} + + where :math:`B` is the batch size, and :math:`\mathbf{x}_i` and :math:`\mathbf{y}_j` are + feature vectors sampled from :math:`P` and :math:`Q`, respectively. + :math:`k(\mathbf{x},\mathbf{y})=\exp(-\| \mathbf{x}-\mathbf{y} \|^2/ 2\sigma^2)` is the Gaussian RBF kernel. + + This metric computes the MMD for each batch and takes the average. + + - ``update`` must receive output of the form ``(x, y)``. + - ``x`` and ``y`` are expected to be in the same shape :math:`(B, \ldots)`. + + Args: + var: the bandwidth :math:`\sigma^2` of the kernel. Default: 1.0 + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. + By default, this metric requires the output as ``(x, y)``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in the format of + ``(x, y)``. If not, ``output_tranform`` can be added + to the metric to transform the output into the form expected by the metric. + + For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. + + .. include:: defaults.rst + :start-after: :orphan: + + .. testcode:: + + metric = MaximumMeanDiscrepancy() + metric.attach(default_evaluator, "mmd") + x = torch.tensor([[-0.80324818, -0.95768364, -0.03807209], + [-0.11059691, -0.38230813, -0.4111988], + [-0.8864329, -0.02890403, -0.60119252], + [-0.68732452, -0.12854739, -0.72095073], + [-0.62604613, -0.52368328, -0.24112842]]) + y = torch.tensor([[0.0686768, 0.80502737, 0.53321717], + [0.83849465, 0.59099726, 0.76385441], + [0.68688272, 0.56833803, 0.98100778], + [0.55267761, 0.13084654, 0.45382906], + [0.0754253, 0.70317304, 0.4756805]]) + state = default_evaluator.run([[x, y]]) + print(state.metrics["mmd"]) + + .. testoutput:: + + 1.0726975202560425 + """ + + _state_dict_all_req_keys = ("_sum_of_mmd", "_num_batches") + + def __init__( + self, var: float = 1.0, output_transform: Callable = lambda x: x, device: torch.device = torch.device("cpu") + ): + self.var = var + super().__init__(output_transform, device) + + @reinit__is_reduced + def reset(self) -> None: + self._sum_of_mmd = torch.tensor(0.0, device=self._device) + self._num_batches = 0 + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + x, y = output[0].detach(), output[1].detach() + if x.shape != y.shape: + raise ValueError(f"x and y must be in the same shape, got {x.shape} != {y.shape}.") + + if x.ndim >= 3: + x = x.flatten(start_dim=1) + y = y.flatten(start_dim=1) + elif x.ndim == 1: + raise ValueError(f"x must be in the shape of (B, ...), got {x.shape}.") + + xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) + rx = xx.diag().unsqueeze(0).expand_as(xx) + ry = yy.diag().unsqueeze(0).expand_as(yy) + + dxx = rx.t() + rx - 2.0 * xx + dyy = ry.t() + ry - 2.0 * yy + dxy = rx.t() + ry - 2.0 * zz + + v = self.var + XX = torch.exp(-0.5 * dxx / v) + YY = torch.exp(-0.5 * dyy / v) + XY = torch.exp(-0.5 * dxy / v) + + # unbiased + n = x.shape[0] + XX = (XX.sum() - n) / (n * (n - 1)) + YY = (YY.sum() - n) / (n * (n - 1)) + XY = XY.sum() / (n * n) + + # mmd cannot be negative + mmd2 = (XX - 2.0 * XY + YY).clamp(min=0.0) + + self._sum_of_mmd += mmd2.sqrt().to(self._device) + self._num_batches += 1 + + @sync_all_reduce("_sum_of_mmd", "_num_batches") + def compute(self) -> float: + if self._num_batches == 0: + raise NotComputableError("MaximumMeanDiscrepacy must have at least one batch before it can be computed.") + return self._sum_of_mmd.item() / self._num_batches diff --git a/tests/ignite/metrics/test_maximum_mean_discrepancy.py b/tests/ignite/metrics/test_maximum_mean_discrepancy.py new file mode 100644 index 000000000000..421efb4e08e3 --- /dev/null +++ b/tests/ignite/metrics/test_maximum_mean_discrepancy.py @@ -0,0 +1,175 @@ +from typing import Tuple + +import numpy as np +import pytest +import torch +from torch import Tensor + +import ignite.distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics import MaximumMeanDiscrepancy + + +def np_mmd(x: np.ndarray, y: np.ndarray, var: float): + n = x.shape[0] + x = x.reshape(n, -1) + y = y.reshape(n, -1) + + a = np.arange(n) + ii, jj = np.meshgrid(a, a, indexing="ij") + XX = np.exp(-np.square(x[ii] - x[jj]).sum(axis=2) / (var * 2)) + XX = (np.sum(XX) - n) / (n * (n - 1)) + + XY = np.exp(-np.square(x[ii] - y[jj]).sum(axis=2) / (var * 2)) + XY = np.sum(XY) / (n * n) + + YY = np.exp(-np.square(y[ii] - y[jj]).sum(axis=2) / (var * 2)) + YY = (np.sum(YY) - n) / (n * (n - 1)) + + mmd2 = np.clip(XX + YY - XY * 2, 0.0, None) + + return np.sqrt(mmd2) + + +def test_zero_sample(): + mmd = MaximumMeanDiscrepancy() + with pytest.raises( + NotComputableError, match=r"MaximumMeanDiscrepacy must have at least one batch before it can be computed" + ): + mmd.compute() + + +def test_shape_mismatch(): + mmd = MaximumMeanDiscrepancy() + x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float) + y = torch.tensor([[-2.0, 1.0]], dtype=torch.float) + with pytest.raises(ValueError, match=r"x and y must be in the same shape, got"): + mmd.update((x, y)) + + +def test_invalid_shape(): + mmd = MaximumMeanDiscrepancy() + x = torch.tensor([2.0, 3.0], dtype=torch.float) + y = torch.tensor([4.0, 5.0], dtype=torch.float) + with pytest.raises(ValueError, match=r"x must be in the shape of \(B, ...\), got"): + mmd.update((x, y)) + + +@pytest.fixture(params=list(range(4))) +def test_case(request): + return [ + (torch.randn((100, 10)), torch.rand((100, 10)), 10 ** np.random.uniform(-1.0, 0.0), 1), + (torch.rand((100, 500)), torch.randn((100, 500)), 10 ** np.random.uniform(-1.0, 0.0), 1), + # updated batches + (torch.normal(0.0, 5.0, size=(100, 10)), torch.rand((100, 10)), 10 ** np.random.uniform(-1.0, 0.0), 16), + (torch.normal(5.0, 3.0, size=(100, 200)), torch.rand((100, 200)), 10 ** np.random.uniform(-1.0, 0.0), 16), + # image segmentation + (torch.randn((100, 5, 32, 32)), torch.rand((100, 5, 32, 32)), 10 ** np.random.uniform(-1.0, 0.0), 32), + (torch.rand((100, 5, 224, 224)), torch.randn((100, 5, 224, 224)), 10 ** np.random.uniform(-1.0, 0.0), 32), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int]): + x, y, var, batch_size = test_case + + mmd = MaximumMeanDiscrepancy(var=var) + mmd.reset() + + if batch_size > 1: + np_mmd_sum = 0.0 + n_iters = y.shape[0] // batch_size + 1 + for i in range(n_iters): + idx = i * batch_size + x_batch, y_batch = x[idx : idx + batch_size], y[idx : idx + batch_size] + mmd.update((x_batch, y_batch)) + + np_mmd_sum += np_mmd(x_batch.cpu().numpy(), y_batch.cpu().numpy(), var) + + np_res = np_mmd_sum / n_iters + else: + mmd.update((x, y)) + np_res = np_mmd(x.cpu().numpy(), y.cpu().numpy(), var) + + res = mmd.compute() + + assert isinstance(res, float) + assert pytest.approx(np_res, abs=1e-5) == res + + +def test_accumulator_detached(): + mmd = MaximumMeanDiscrepancy() + + x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float) + y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float) + mmd.update((x, y)) + + assert not mmd._sum_of_mmd.requires_grad + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_integration(self): + tol = 1e-5 + n_iters = 100 + batch_size = 10 + n_dims = 100 + + rank = idist.get_rank() + torch.manual_seed(12 + rank) + + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + for metric_device in metric_devices: + y = torch.randn((n_iters * batch_size, n_dims)).float().to(device) + x = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_dims)).float().to(device) + + def data_loader(i): + return x[i * batch_size : (i + 1) * batch_size], y[i * batch_size : (i + 1) * batch_size] + + engine = Engine(lambda e, i: data_loader(i)) + + m = MaximumMeanDiscrepancy(device=metric_device) + m.attach(engine, "mmd") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) + + x = idist.all_gather(x) + y = idist.all_gather(y) + + assert "mmd" in engine.state.metrics + res = engine.state.metrics["mmd"] + + # compute numpy mmd + true_res = 0.0 + for i in range(n_iters): + x_batch, y_batch = data_loader(i) + x_np = x_batch.cpu().numpy() + y_np = y_batch.cpu().numpy() + true_res += np_mmd(x_np, y_np) + + true_res /= n_iters + assert pytest.approx(true_res, abs=tol) == res + + def test_accumulator_device(self): + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + for metric_device in metric_devices: + mmd = MaximumMeanDiscrepancy(device=metric_device) + + for dev in (mmd._device, mmd._sum_of_mmd.device): + assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" + + x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float() + y = torch.ones(2, 2).float() + mmd.update((x, y)) + + for dev in (mmd._device, mmd._sum_of_mmd.device): + assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" From 59f0386fde0cb09783fe580802bee6a3575176cf Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Wed, 1 May 2024 23:26:23 +0900 Subject: [PATCH 02/10] fix URL --- ignite/metrics/maximum_mean_discrepancy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/maximum_mean_discrepancy.py b/ignite/metrics/maximum_mean_discrepancy.py index e5edc0f801a7..4bce7c129530 100644 --- a/ignite/metrics/maximum_mean_discrepancy.py +++ b/ignite/metrics/maximum_mean_discrepancy.py @@ -10,7 +10,7 @@ class MaximumMeanDiscrepancy(Metric): r"""Calculates the mean of `maximum mean discrepancy (MMD) - `_. + `_. .. math:: \begin{align*} From dfc495dca1331e251ba62d825b91ed2a3cebef40 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Fri, 3 May 2024 21:27:14 +0900 Subject: [PATCH 03/10] update formula --- ignite/metrics/maximum_mean_discrepancy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ignite/metrics/maximum_mean_discrepancy.py b/ignite/metrics/maximum_mean_discrepancy.py index 4bce7c129530..8e0af9f6276c 100644 --- a/ignite/metrics/maximum_mean_discrepancy.py +++ b/ignite/metrics/maximum_mean_discrepancy.py @@ -16,9 +16,9 @@ class MaximumMeanDiscrepancy(Metric): \begin{align*} \text{MMD}^2 (P,Q) &= \underset{\| f \| \leq 1}{\text{sup}} | \mathbb{E}_{X\sim P}[f(X)] - \mathbb{E}_{Y\sim Q}[f(Y)] |^2 \\ - &\approx \frac{1}{B(B-1)} \sum_{i} \sum_{j\neq i}k(\mathbf{x}_i,\mathbf{x}_j) - -\frac{2}{B^2}\sum_{i} \sum_{j} k(\mathbf{x}_i,\mathbf{y}_j) - + \frac{1}{B(B-1)} \sum_{i} \sum_{j\neq i}k(\mathbf{y}_i,\mathbf{y}_j) + &\approx \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{x}_i,\mathbf{x}_j) + -\frac{2}{B^2}\sum_{i=1}^B \sum_{j=1}^B k(\mathbf{x}_i,\mathbf{y}_j) + + \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{y}_i,\mathbf{y}_j) \end{align*} where :math:`B` is the batch size, and :math:`\mathbf{x}_i` and :math:`\mathbf{y}_j` are From c5eda6e0b798131cc46ce3a86e322adaa51e127f Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Fri, 3 May 2024 21:40:45 +0900 Subject: [PATCH 04/10] modify test for MMD --- tests/ignite/metrics/test_maximum_mean_discrepancy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ignite/metrics/test_maximum_mean_discrepancy.py b/tests/ignite/metrics/test_maximum_mean_discrepancy.py index 421efb4e08e3..0dc7e18d9d11 100644 --- a/tests/ignite/metrics/test_maximum_mean_discrepancy.py +++ b/tests/ignite/metrics/test_maximum_mean_discrepancy.py @@ -95,7 +95,7 @@ def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int]): res = mmd.compute() assert isinstance(res, float) - assert pytest.approx(np_res, abs=1e-5) == res + assert pytest.approx(np_res, abs=1e-4) == res def test_accumulator_detached(): @@ -111,7 +111,7 @@ def test_accumulator_detached(): @pytest.mark.usefixtures("distributed") class TestDistributed: def test_integration(self): - tol = 1e-5 + tol = 1e-4 n_iters = 100 batch_size = 10 n_dims = 100 From c3346c59d9c3a0d2aa4dbe711ec986c72496e93e Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sat, 4 May 2024 00:50:28 +0900 Subject: [PATCH 05/10] set default var value for np_mmd --- tests/ignite/metrics/test_maximum_mean_discrepancy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ignite/metrics/test_maximum_mean_discrepancy.py b/tests/ignite/metrics/test_maximum_mean_discrepancy.py index 0dc7e18d9d11..c6400393296b 100644 --- a/tests/ignite/metrics/test_maximum_mean_discrepancy.py +++ b/tests/ignite/metrics/test_maximum_mean_discrepancy.py @@ -11,7 +11,7 @@ from ignite.metrics import MaximumMeanDiscrepancy -def np_mmd(x: np.ndarray, y: np.ndarray, var: float): +def np_mmd(x: np.ndarray, y: np.ndarray, var: float = 1.0): n = x.shape[0] x = x.reshape(n, -1) y = y.reshape(n, -1) From 32ad1dbbf2b93c6854c649b7fb86b2499bf9ace4 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Mon, 6 May 2024 16:12:51 +0900 Subject: [PATCH 06/10] accumulate mmd2 --- ignite/metrics/maximum_mean_discrepancy.py | 10 ++++---- .../metrics/test_maximum_mean_discrepancy.py | 23 +++++++++---------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/ignite/metrics/maximum_mean_discrepancy.py b/ignite/metrics/maximum_mean_discrepancy.py index 8e0af9f6276c..5afb9e0461c9 100644 --- a/ignite/metrics/maximum_mean_discrepancy.py +++ b/ignite/metrics/maximum_mean_discrepancy.py @@ -74,7 +74,7 @@ class MaximumMeanDiscrepancy(Metric): 1.0726975202560425 """ - _state_dict_all_req_keys = ("_sum_of_mmd", "_num_batches") + _state_dict_all_req_keys = ("_sum_of_mmd2", "_num_batches") def __init__( self, var: float = 1.0, output_transform: Callable = lambda x: x, device: torch.device = torch.device("cpu") @@ -84,7 +84,7 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._sum_of_mmd = torch.tensor(0.0, device=self._device) + self._sum_of_mmd2 = torch.tensor(0.0, device=self._device) self._num_batches = 0 @reinit__is_reduced @@ -121,11 +121,11 @@ def update(self, output: Sequence[torch.Tensor]) -> None: # mmd cannot be negative mmd2 = (XX - 2.0 * XY + YY).clamp(min=0.0) - self._sum_of_mmd += mmd2.sqrt().to(self._device) + self._sum_of_mmd2 += mmd2.to(self._device) self._num_batches += 1 - @sync_all_reduce("_sum_of_mmd", "_num_batches") + @sync_all_reduce("_sum_of_mmd2", "_num_batches") def compute(self) -> float: if self._num_batches == 0: raise NotComputableError("MaximumMeanDiscrepacy must have at least one batch before it can be computed.") - return self._sum_of_mmd.item() / self._num_batches + return (self._sum_of_mmd2 / self._num_batches).sqrt().item() diff --git a/tests/ignite/metrics/test_maximum_mean_discrepancy.py b/tests/ignite/metrics/test_maximum_mean_discrepancy.py index c6400393296b..ed43439a4335 100644 --- a/tests/ignite/metrics/test_maximum_mean_discrepancy.py +++ b/tests/ignite/metrics/test_maximum_mean_discrepancy.py @@ -11,7 +11,7 @@ from ignite.metrics import MaximumMeanDiscrepancy -def np_mmd(x: np.ndarray, y: np.ndarray, var: float = 1.0): +def np_mmd2(x: np.ndarray, y: np.ndarray, var: float = 1.0): n = x.shape[0] x = x.reshape(n, -1) y = y.reshape(n, -1) @@ -28,8 +28,7 @@ def np_mmd(x: np.ndarray, y: np.ndarray, var: float = 1.0): YY = (np.sum(YY) - n) / (n * (n - 1)) mmd2 = np.clip(XX + YY - XY * 2, 0.0, None) - - return np.sqrt(mmd2) + return mmd2 def test_zero_sample(): @@ -78,19 +77,19 @@ def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int]): mmd.reset() if batch_size > 1: - np_mmd_sum = 0.0 + np_mmd2_sum = 0.0 n_iters = y.shape[0] // batch_size + 1 for i in range(n_iters): idx = i * batch_size x_batch, y_batch = x[idx : idx + batch_size], y[idx : idx + batch_size] mmd.update((x_batch, y_batch)) - np_mmd_sum += np_mmd(x_batch.cpu().numpy(), y_batch.cpu().numpy(), var) + np_mmd2_sum += np_mmd2(x_batch.cpu().numpy(), y_batch.cpu().numpy(), var) - np_res = np_mmd_sum / n_iters + np_res = np.sqrt(np_mmd2_sum / n_iters) else: mmd.update((x, y)) - np_res = np_mmd(x.cpu().numpy(), y.cpu().numpy(), var) + np_res = np.sqrt(np_mmd2(x.cpu().numpy(), y.cpu().numpy(), var)) res = mmd.compute() @@ -105,7 +104,7 @@ def test_accumulator_detached(): y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float) mmd.update((x, y)) - assert not mmd._sum_of_mmd.requires_grad + assert not mmd._sum_of_mmd2.requires_grad @pytest.mark.usefixtures("distributed") @@ -151,9 +150,9 @@ def data_loader(i): x_batch, y_batch = data_loader(i) x_np = x_batch.cpu().numpy() y_np = y_batch.cpu().numpy() - true_res += np_mmd(x_np, y_np) + true_res += np_mmd2(x_np, y_np) - true_res /= n_iters + true_res = np.sqrt(true_res / n_iters) assert pytest.approx(true_res, abs=tol) == res def test_accumulator_device(self): @@ -164,12 +163,12 @@ def test_accumulator_device(self): for metric_device in metric_devices: mmd = MaximumMeanDiscrepancy(device=metric_device) - for dev in (mmd._device, mmd._sum_of_mmd.device): + for dev in (mmd._device, mmd._sum_of_mmd2.device): assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float() y = torch.ones(2, 2).float() mmd.update((x, y)) - for dev in (mmd._device, mmd._sum_of_mmd.device): + for dev in (mmd._device, mmd._sum_of_mmd2.device): assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" From cc5d555157228e65c5ac367f3fce88d970317e99 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Tue, 7 May 2024 22:05:39 +0900 Subject: [PATCH 07/10] accumulate sum of xx, yy, and xy --- ignite/metrics/maximum_mean_discrepancy.py | 13 ++++++++----- .../ignite/metrics/test_maximum_mean_discrepancy.py | 8 +++++--- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/ignite/metrics/maximum_mean_discrepancy.py b/ignite/metrics/maximum_mean_discrepancy.py index 5afb9e0461c9..9e8fe3e8e4a2 100644 --- a/ignite/metrics/maximum_mean_discrepancy.py +++ b/ignite/metrics/maximum_mean_discrepancy.py @@ -84,7 +84,9 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._sum_of_mmd2 = torch.tensor(0.0, device=self._device) + self._xx_sum = torch.tensor(0.0, device=self._device) + self._yy_sum = torch.tensor(0.0, device=self._device) + self._xy_sum = torch.tensor(0.0, device=self._device) self._num_batches = 0 @reinit__is_reduced @@ -118,14 +120,15 @@ def update(self, output: Sequence[torch.Tensor]) -> None: YY = (YY.sum() - n) / (n * (n - 1)) XY = XY.sum() / (n * n) - # mmd cannot be negative - mmd2 = (XX - 2.0 * XY + YY).clamp(min=0.0) + self._xx_sum += XX.to(self._device) + self._yy_sum += YY.to(self._device) + self._xy_sum += XY.to(self._device) - self._sum_of_mmd2 += mmd2.to(self._device) self._num_batches += 1 @sync_all_reduce("_sum_of_mmd2", "_num_batches") def compute(self) -> float: if self._num_batches == 0: raise NotComputableError("MaximumMeanDiscrepacy must have at least one batch before it can be computed.") - return (self._sum_of_mmd2 / self._num_batches).sqrt().item() + mmd2 = (self._xx_sum + self._yy_sum - 2.0 * self._xy_sum).clamp(min=0.0) / self._num_batches + return mmd2.sqrt().item() diff --git a/tests/ignite/metrics/test_maximum_mean_discrepancy.py b/tests/ignite/metrics/test_maximum_mean_discrepancy.py index ed43439a4335..78fec15504a5 100644 --- a/tests/ignite/metrics/test_maximum_mean_discrepancy.py +++ b/tests/ignite/metrics/test_maximum_mean_discrepancy.py @@ -104,7 +104,7 @@ def test_accumulator_detached(): y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float) mmd.update((x, y)) - assert not mmd._sum_of_mmd2.requires_grad + assert not any(acc.requires_grad for acc in (mmd._xx_sum, mmd._yy_sum, mmd._xy_sum)) @pytest.mark.usefixtures("distributed") @@ -163,12 +163,14 @@ def test_accumulator_device(self): for metric_device in metric_devices: mmd = MaximumMeanDiscrepancy(device=metric_device) - for dev in (mmd._device, mmd._sum_of_mmd2.device): + devices = (mmd._device, mmd._xx_sum, mmd._yy_sum, mmd._xy_sum) + for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float() y = torch.ones(2, 2).float() mmd.update((x, y)) - for dev in (mmd._device, mmd._sum_of_mmd2.device): + devices = (mmd._device, mmd._xx_sum, mmd._yy_sum, mmd._xy_sum) + for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" From 2884f70c121c0422bd36bf483e5d9b12406b68d0 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Tue, 7 May 2024 23:00:34 +0900 Subject: [PATCH 08/10] add reference paper to docstring --- ignite/metrics/maximum_mean_discrepancy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ignite/metrics/maximum_mean_discrepancy.py b/ignite/metrics/maximum_mean_discrepancy.py index 9e8fe3e8e4a2..e0a18362b36d 100644 --- a/ignite/metrics/maximum_mean_discrepancy.py +++ b/ignite/metrics/maximum_mean_discrepancy.py @@ -27,6 +27,10 @@ class MaximumMeanDiscrepancy(Metric): This metric computes the MMD for each batch and takes the average. + More details can be found in `Gretton et al. 2012`__. + + __ https://jmlr.csail.mit.edu/papers/v13/gretton12a.html + - ``update`` must receive output of the form ``(x, y)``. - ``x`` and ``y`` are expected to be in the same shape :math:`(B, \ldots)`. From 9f2ca1b1bbf9d9630393b551719cc60a7803769d Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Tue, 7 May 2024 23:06:04 +0900 Subject: [PATCH 09/10] fix accumulator variables --- ignite/metrics/maximum_mean_discrepancy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/maximum_mean_discrepancy.py b/ignite/metrics/maximum_mean_discrepancy.py index e0a18362b36d..d92dd5448ce1 100644 --- a/ignite/metrics/maximum_mean_discrepancy.py +++ b/ignite/metrics/maximum_mean_discrepancy.py @@ -78,7 +78,7 @@ class MaximumMeanDiscrepancy(Metric): 1.0726975202560425 """ - _state_dict_all_req_keys = ("_sum_of_mmd2", "_num_batches") + _state_dict_all_req_keys = ("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches") def __init__( self, var: float = 1.0, output_transform: Callable = lambda x: x, device: torch.device = torch.device("cpu") @@ -130,7 +130,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: self._num_batches += 1 - @sync_all_reduce("_sum_of_mmd2", "_num_batches") + @sync_all_reduce("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches") def compute(self) -> float: if self._num_batches == 0: raise NotComputableError("MaximumMeanDiscrepacy must have at least one batch before it can be computed.") From 52773b19cf22013922c19459c4a743093e9b3840 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Wed, 8 May 2024 00:00:56 +0900 Subject: [PATCH 10/10] fix test_accumulator_device --- tests/ignite/metrics/test_maximum_mean_discrepancy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ignite/metrics/test_maximum_mean_discrepancy.py b/tests/ignite/metrics/test_maximum_mean_discrepancy.py index 78fec15504a5..8cfc5f55567d 100644 --- a/tests/ignite/metrics/test_maximum_mean_discrepancy.py +++ b/tests/ignite/metrics/test_maximum_mean_discrepancy.py @@ -163,7 +163,7 @@ def test_accumulator_device(self): for metric_device in metric_devices: mmd = MaximumMeanDiscrepancy(device=metric_device) - devices = (mmd._device, mmd._xx_sum, mmd._yy_sum, mmd._xy_sum) + devices = (mmd._device, mmd._xx_sum.device, mmd._yy_sum.device, mmd._xy_sum.device) for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" @@ -171,6 +171,6 @@ def test_accumulator_device(self): y = torch.ones(2, 2).float() mmd.update((x, y)) - devices = (mmd._device, mmd._xx_sum, mmd._yy_sum, mmd._xy_sum) + devices = (mmd._device, mmd._xx_sum.device, mmd._yy_sum.device, mmd._xy_sum.device) for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"