-
-
Notifications
You must be signed in to change notification settings - Fork 650
Add MaximumMeanDiscrepancy metric #3243
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
9ace0b4
add MaximumMeanDiscrepancy metric
kzkadc 59f0386
fix URL
kzkadc 8f82909
Merge branch 'master' into mmd
kzkadc dfc495d
update formula
kzkadc c5eda6e
modify test for MMD
kzkadc c3346c5
set default var value for np_mmd
kzkadc 32ad1db
accumulate mmd2
kzkadc cc5d555
accumulate sum of xx, yy, and xy
kzkadc 2884f70
add reference paper to docstring
kzkadc 9f2ca1b
fix accumulator variables
kzkadc 52773b1
fix test_accumulator_device
kzkadc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
from typing import Callable, Sequence | ||
|
||
import torch | ||
|
||
from ignite.exceptions import NotComputableError | ||
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce | ||
|
||
__all__ = ["MaximumMeanDiscrepancy"] | ||
|
||
|
||
class MaximumMeanDiscrepancy(Metric): | ||
r"""Calculates the mean of `maximum mean discrepancy (MMD) | ||
<https://www.onurtunali.com/ml/2019/03/08/maximum-mean-discrepancy-in-machine-learning.html>`_. | ||
|
||
.. math:: | ||
\begin{align*} | ||
\text{MMD}^2 (P,Q) &= \underset{\| f \| \leq 1}{\text{sup}} | \mathbb{E}_{X\sim P}[f(X)] | ||
- \mathbb{E}_{Y\sim Q}[f(Y)] |^2 \\ | ||
&\approx \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{x}_i,\mathbf{x}_j) | ||
-\frac{2}{B^2}\sum_{i=1}^B \sum_{j=1}^B k(\mathbf{x}_i,\mathbf{y}_j) | ||
+ \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{y}_i,\mathbf{y}_j) | ||
\end{align*} | ||
|
||
where :math:`B` is the batch size, and :math:`\mathbf{x}_i` and :math:`\mathbf{y}_j` are | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
feature vectors sampled from :math:`P` and :math:`Q`, respectively. | ||
:math:`k(\mathbf{x},\mathbf{y})=\exp(-\| \mathbf{x}-\mathbf{y} \|^2/ 2\sigma^2)` is the Gaussian RBF kernel. | ||
|
||
This metric computes the MMD for each batch and takes the average. | ||
|
||
More details can be found in `Gretton et al. 2012`__. | ||
|
||
__ https://jmlr.csail.mit.edu/papers/v13/gretton12a.html | ||
|
||
- ``update`` must receive output of the form ``(x, y)``. | ||
- ``x`` and ``y`` are expected to be in the same shape :math:`(B, \ldots)`. | ||
|
||
Args: | ||
var: the bandwidth :math:`\sigma^2` of the kernel. Default: 1.0 | ||
output_transform: a callable that is used to transform the | ||
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the | ||
form expected by the metric. This can be useful if, for example, you have a multi-output model and | ||
you want to compute the metric with respect to one of the outputs. | ||
By default, this metric requires the output as ``(x, y)``. | ||
device: specifies which device updates are accumulated on. Setting the | ||
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is | ||
non-blocking. By default, CPU. | ||
|
||
Examples: | ||
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. | ||
The output of the engine's ``process_function`` needs to be in the format of | ||
``(x, y)``. If not, ``output_tranform`` can be added | ||
to the metric to transform the output into the form expected by the metric. | ||
|
||
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. | ||
|
||
.. include:: defaults.rst | ||
:start-after: :orphan: | ||
|
||
.. testcode:: | ||
|
||
metric = MaximumMeanDiscrepancy() | ||
metric.attach(default_evaluator, "mmd") | ||
x = torch.tensor([[-0.80324818, -0.95768364, -0.03807209], | ||
[-0.11059691, -0.38230813, -0.4111988], | ||
[-0.8864329, -0.02890403, -0.60119252], | ||
[-0.68732452, -0.12854739, -0.72095073], | ||
[-0.62604613, -0.52368328, -0.24112842]]) | ||
y = torch.tensor([[0.0686768, 0.80502737, 0.53321717], | ||
[0.83849465, 0.59099726, 0.76385441], | ||
[0.68688272, 0.56833803, 0.98100778], | ||
[0.55267761, 0.13084654, 0.45382906], | ||
[0.0754253, 0.70317304, 0.4756805]]) | ||
state = default_evaluator.run([[x, y]]) | ||
print(state.metrics["mmd"]) | ||
|
||
.. testoutput:: | ||
|
||
1.0726975202560425 | ||
""" | ||
|
||
_state_dict_all_req_keys = ("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches") | ||
|
||
def __init__( | ||
self, var: float = 1.0, output_transform: Callable = lambda x: x, device: torch.device = torch.device("cpu") | ||
): | ||
self.var = var | ||
super().__init__(output_transform, device) | ||
|
||
@reinit__is_reduced | ||
def reset(self) -> None: | ||
self._xx_sum = torch.tensor(0.0, device=self._device) | ||
self._yy_sum = torch.tensor(0.0, device=self._device) | ||
self._xy_sum = torch.tensor(0.0, device=self._device) | ||
self._num_batches = 0 | ||
|
||
@reinit__is_reduced | ||
def update(self, output: Sequence[torch.Tensor]) -> None: | ||
x, y = output[0].detach(), output[1].detach() | ||
if x.shape != y.shape: | ||
raise ValueError(f"x and y must be in the same shape, got {x.shape} != {y.shape}.") | ||
|
||
if x.ndim >= 3: | ||
x = x.flatten(start_dim=1) | ||
y = y.flatten(start_dim=1) | ||
elif x.ndim == 1: | ||
raise ValueError(f"x must be in the shape of (B, ...), got {x.shape}.") | ||
|
||
xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) | ||
rx = xx.diag().unsqueeze(0).expand_as(xx) | ||
ry = yy.diag().unsqueeze(0).expand_as(yy) | ||
|
||
dxx = rx.t() + rx - 2.0 * xx | ||
dyy = ry.t() + ry - 2.0 * yy | ||
dxy = rx.t() + ry - 2.0 * zz | ||
|
||
v = self.var | ||
XX = torch.exp(-0.5 * dxx / v) | ||
YY = torch.exp(-0.5 * dyy / v) | ||
XY = torch.exp(-0.5 * dxy / v) | ||
|
||
# unbiased | ||
n = x.shape[0] | ||
XX = (XX.sum() - n) / (n * (n - 1)) | ||
YY = (YY.sum() - n) / (n * (n - 1)) | ||
XY = XY.sum() / (n * n) | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
self._xx_sum += XX.to(self._device) | ||
self._yy_sum += YY.to(self._device) | ||
self._xy_sum += XY.to(self._device) | ||
|
||
self._num_batches += 1 | ||
kzkadc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@sync_all_reduce("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches") | ||
def compute(self) -> float: | ||
if self._num_batches == 0: | ||
raise NotComputableError("MaximumMeanDiscrepacy must have at least one batch before it can be computed.") | ||
mmd2 = (self._xx_sum + self._yy_sum - 2.0 * self._xy_sum).clamp(min=0.0) / self._num_batches | ||
return mmd2.sqrt().item() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
from torch import Tensor | ||
|
||
import ignite.distributed as idist | ||
from ignite.engine import Engine | ||
from ignite.exceptions import NotComputableError | ||
from ignite.metrics import MaximumMeanDiscrepancy | ||
|
||
|
||
def np_mmd2(x: np.ndarray, y: np.ndarray, var: float = 1.0): | ||
n = x.shape[0] | ||
x = x.reshape(n, -1) | ||
y = y.reshape(n, -1) | ||
|
||
a = np.arange(n) | ||
ii, jj = np.meshgrid(a, a, indexing="ij") | ||
XX = np.exp(-np.square(x[ii] - x[jj]).sum(axis=2) / (var * 2)) | ||
XX = (np.sum(XX) - n) / (n * (n - 1)) | ||
|
||
XY = np.exp(-np.square(x[ii] - y[jj]).sum(axis=2) / (var * 2)) | ||
XY = np.sum(XY) / (n * n) | ||
|
||
YY = np.exp(-np.square(y[ii] - y[jj]).sum(axis=2) / (var * 2)) | ||
YY = (np.sum(YY) - n) / (n * (n - 1)) | ||
|
||
mmd2 = np.clip(XX + YY - XY * 2, 0.0, None) | ||
return mmd2 | ||
|
||
|
||
def test_zero_sample(): | ||
mmd = MaximumMeanDiscrepancy() | ||
with pytest.raises( | ||
NotComputableError, match=r"MaximumMeanDiscrepacy must have at least one batch before it can be computed" | ||
): | ||
mmd.compute() | ||
|
||
|
||
def test_shape_mismatch(): | ||
mmd = MaximumMeanDiscrepancy() | ||
x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float) | ||
y = torch.tensor([[-2.0, 1.0]], dtype=torch.float) | ||
with pytest.raises(ValueError, match=r"x and y must be in the same shape, got"): | ||
mmd.update((x, y)) | ||
|
||
|
||
def test_invalid_shape(): | ||
mmd = MaximumMeanDiscrepancy() | ||
x = torch.tensor([2.0, 3.0], dtype=torch.float) | ||
y = torch.tensor([4.0, 5.0], dtype=torch.float) | ||
with pytest.raises(ValueError, match=r"x must be in the shape of \(B, ...\), got"): | ||
mmd.update((x, y)) | ||
|
||
|
||
@pytest.fixture(params=list(range(4))) | ||
def test_case(request): | ||
return [ | ||
(torch.randn((100, 10)), torch.rand((100, 10)), 10 ** np.random.uniform(-1.0, 0.0), 1), | ||
(torch.rand((100, 500)), torch.randn((100, 500)), 10 ** np.random.uniform(-1.0, 0.0), 1), | ||
# updated batches | ||
(torch.normal(0.0, 5.0, size=(100, 10)), torch.rand((100, 10)), 10 ** np.random.uniform(-1.0, 0.0), 16), | ||
(torch.normal(5.0, 3.0, size=(100, 200)), torch.rand((100, 200)), 10 ** np.random.uniform(-1.0, 0.0), 16), | ||
# image segmentation | ||
(torch.randn((100, 5, 32, 32)), torch.rand((100, 5, 32, 32)), 10 ** np.random.uniform(-1.0, 0.0), 32), | ||
(torch.rand((100, 5, 224, 224)), torch.randn((100, 5, 224, 224)), 10 ** np.random.uniform(-1.0, 0.0), 32), | ||
][request.param] | ||
|
||
|
||
@pytest.mark.parametrize("n_times", range(5)) | ||
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int]): | ||
x, y, var, batch_size = test_case | ||
|
||
mmd = MaximumMeanDiscrepancy(var=var) | ||
mmd.reset() | ||
|
||
if batch_size > 1: | ||
np_mmd2_sum = 0.0 | ||
n_iters = y.shape[0] // batch_size + 1 | ||
for i in range(n_iters): | ||
idx = i * batch_size | ||
x_batch, y_batch = x[idx : idx + batch_size], y[idx : idx + batch_size] | ||
mmd.update((x_batch, y_batch)) | ||
|
||
np_mmd2_sum += np_mmd2(x_batch.cpu().numpy(), y_batch.cpu().numpy(), var) | ||
|
||
np_res = np.sqrt(np_mmd2_sum / n_iters) | ||
else: | ||
mmd.update((x, y)) | ||
np_res = np.sqrt(np_mmd2(x.cpu().numpy(), y.cpu().numpy(), var)) | ||
|
||
res = mmd.compute() | ||
|
||
assert isinstance(res, float) | ||
assert pytest.approx(np_res, abs=1e-4) == res | ||
|
||
|
||
def test_accumulator_detached(): | ||
mmd = MaximumMeanDiscrepancy() | ||
|
||
x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float) | ||
y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float) | ||
mmd.update((x, y)) | ||
|
||
assert not any(acc.requires_grad for acc in (mmd._xx_sum, mmd._yy_sum, mmd._xy_sum)) | ||
|
||
|
||
@pytest.mark.usefixtures("distributed") | ||
class TestDistributed: | ||
def test_integration(self): | ||
tol = 1e-4 | ||
n_iters = 100 | ||
batch_size = 10 | ||
n_dims = 100 | ||
|
||
rank = idist.get_rank() | ||
torch.manual_seed(12 + rank) | ||
|
||
device = idist.device() | ||
metric_devices = [torch.device("cpu")] | ||
if device.type != "xla": | ||
metric_devices.append(device) | ||
|
||
for metric_device in metric_devices: | ||
y = torch.randn((n_iters * batch_size, n_dims)).float().to(device) | ||
x = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_dims)).float().to(device) | ||
|
||
def data_loader(i): | ||
return x[i * batch_size : (i + 1) * batch_size], y[i * batch_size : (i + 1) * batch_size] | ||
|
||
engine = Engine(lambda e, i: data_loader(i)) | ||
|
||
m = MaximumMeanDiscrepancy(device=metric_device) | ||
m.attach(engine, "mmd") | ||
|
||
data = list(range(n_iters)) | ||
engine.run(data=data, max_epochs=1) | ||
|
||
x = idist.all_gather(x) | ||
y = idist.all_gather(y) | ||
|
||
assert "mmd" in engine.state.metrics | ||
res = engine.state.metrics["mmd"] | ||
|
||
# compute numpy mmd | ||
true_res = 0.0 | ||
for i in range(n_iters): | ||
x_batch, y_batch = data_loader(i) | ||
x_np = x_batch.cpu().numpy() | ||
y_np = y_batch.cpu().numpy() | ||
true_res += np_mmd2(x_np, y_np) | ||
|
||
true_res = np.sqrt(true_res / n_iters) | ||
assert pytest.approx(true_res, abs=tol) == res | ||
|
||
def test_accumulator_device(self): | ||
device = idist.device() | ||
metric_devices = [torch.device("cpu")] | ||
if device.type != "xla": | ||
metric_devices.append(device) | ||
for metric_device in metric_devices: | ||
mmd = MaximumMeanDiscrepancy(device=metric_device) | ||
|
||
devices = (mmd._device, mmd._xx_sum.device, mmd._yy_sum.device, mmd._xy_sum.device) | ||
for dev in devices: | ||
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" | ||
|
||
x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float() | ||
y = torch.ones(2, 2).float() | ||
mmd.update((x, y)) | ||
|
||
devices = (mmd._device, mmd._xx_sum.device, mmd._yy_sum.device, mmd._xy_sum.device) | ||
for dev in devices: | ||
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.