Add KLDivergence Metric #247

Merged
19 commits merged on Jun 28, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Inception Score metric to image module ([#299](https://github.com/PyTorchLightning/metrics/pull/299))
- Added KID metric to image module ([#301](https://github.com/PyTorchLightning/metrics/pull/301))
- Added `sync` and `sync_context` methods for manually controlling when metric states are synced ([#302](https://github.com/PyTorchLightning/metrics/pull/302))
- Added `KLDivergence` metric ([#247](https://github.com/PyTorchLightning/metrics/pull/247))

### Changed

1 change: 1 addition & 0 deletions docs/source/pages/overview.rst
@@ -167,6 +167,7 @@ the following limitations:

- :ref:`references/modules:PSNR` and :ref:`references/functional:psnr [func]`
- :ref:`references/modules:SSIM` and :ref:`references/functional:ssim [func]`
- :ref:`references/modules:KLDivergence` and :ref:`references/functional:kldivergence [func]`

******************
Metric Arithmetics
6 changes: 6 additions & 0 deletions docs/source/references/functional.rst
@@ -114,6 +114,12 @@ iou [func]
.. autofunction:: torchmetrics.functional.iou
:noindex:

kldivergence [func]
~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.kldivergence
:noindex:

matthews_corrcoef [func]
~~~~~~~~~~~~~~~~~~~~~~~~

6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
@@ -245,6 +245,12 @@ IoU
.. autoclass:: torchmetrics.IoU
:noindex:

KLDivergence
~~~~~~~~~~~~

.. autoclass:: torchmetrics.KLDivergence
:noindex:

MatthewsCorrcoef
~~~~~~~~~~~~~~~~

115 changes: 115 additions & 0 deletions tests/classification/test_kldivergence.py
@@ -0,0 +1,115 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial
from typing import Optional

import numpy as np
import pytest
import torch
from scipy.stats import entropy
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, MetricTester
from torchmetrics.classification import KLDivergence
from torchmetrics.functional import kldivergence

seed_all(42)

Input = namedtuple('Input', ["p", "q"])

_probs_inputs = Input(
p=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM),
q=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM),
)

_log_probs_inputs = Input(
p=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM).softmax(dim=-1).log(),
q=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM).softmax(dim=-1).log(),
)


def _sk_metric(p: Tensor, q: Tensor, log_prob: bool, reduction: Optional[str] = 'mean'):
if log_prob:
p = p.softmax(dim=-1)
q = q.softmax(dim=-1)
res = entropy(p, q, axis=1)
if reduction == 'mean':
return np.mean(res)
elif reduction == 'sum':
return np.sum(res)
else:
return res


@pytest.mark.parametrize("reduction", ['mean', 'sum'])
@pytest.mark.parametrize(
"p, q, log_prob", [(_probs_inputs.p, _probs_inputs.q, False), (_log_probs_inputs.p, _log_probs_inputs.q, True)]
)
class TestKLDivergence(MetricTester):
atol = 1e-6

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_kldivergence(self, reduction, p, q, log_prob, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
p,
q,
KLDivergence,
partial(_sk_metric, log_prob=log_prob, reduction=reduction),
dist_sync_on_step,
metric_args=dict(log_prob=log_prob, reduction=reduction),
)

def test_kldivergence_functional(self, reduction, p, q, log_prob):
# todo: `num_outputs` is unused
self.run_functional_metric_test(
p,
q,
kldivergence,
partial(_sk_metric, log_prob=log_prob, reduction=reduction),
metric_args=dict(log_prob=log_prob, reduction=reduction),
)

def test_kldivergence_differentiability(self, reduction, p, q, log_prob):
self.run_differentiability_test(
p,
q,
metric_module=KLDivergence,
metric_functional=kldivergence,
metric_args=dict(log_prob=log_prob, reduction=reduction)
)

# KLDivergence half + cpu does not work due to missing support in torch.clamp
@pytest.mark.xfail(reason="KLDivergence metric does not support cpu + half precision")
def test_kldivergence_half_cpu(self, reduction, p, q, log_prob):
self.run_precision_test_cpu(p, q, KLDivergence, kldivergence, {'log_prob': log_prob, 'reduction': reduction})

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_kldivergence_half_gpu(self, reduction, p, q, log_prob):
self.run_precision_test_gpu(p, q, KLDivergence, kldivergence, {'log_prob': log_prob, 'reduction': reduction})


def test_error_on_different_shape():
metric = KLDivergence()
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100, ), torch.randn(50, ))


def test_error_on_multidim_tensors():
metric = KLDivergence()
with pytest.raises(ValueError, match='Expected both p and q distribution to be 2D but got 3 and 3 respectively'):
metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5))
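
As a side note on the fixtures above: a minimal sketch (not part of the test suite; it assumes only the functional API added in this diff) of the equivalence that `_probs_inputs` and `_log_probs_inputs` are meant to exercise, namely that probability inputs and their logarithms with ``log_prob=True`` yield the same score:

import torch

from torchmetrics.functional import kldivergence

# probabilities and their logarithms describe the same distributions
p = torch.rand(4, 3).softmax(dim=-1)
q = torch.rand(4, 3).softmax(dim=-1)

v_probs = kldivergence(p, q, log_prob=False)            # inputs are (re)normalized internally
v_logs = kldivergence(p.log(), q.log(), log_prob=True)  # inputs treated as log-probabilities

assert torch.allclose(v_probs, v_logs, atol=1e-5)
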
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
@@ -29,6 +29,7 @@
HammingDistance,
Hinge,
IoU,
KLDivergence,
MatthewsCorrcoef,
Precision,
PrecisionRecallCurve,
1 change: 1 addition & 0 deletions torchmetrics/classification/__init__.py
@@ -24,6 +24,7 @@
from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import Hinge # noqa: F401
from torchmetrics.classification.iou import IoU # noqa: F401
from torchmetrics.classification.kldivergence import KLDivergence # noqa: F401
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef # noqa: F401
from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401
from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
108 changes: 108 additions & 0 deletions torchmetrics/classification/kldivergence.py
@@ -0,0 +1,108 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

import torch
from torch import Tensor

from torchmetrics.functional.classification.kldivergence import _kld_compute, _kld_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat


class KLDivergence(Metric):
r"""Computes the `KL divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_:

.. math::
D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}

where :math:`P` and :math:`Q` are probability distributions, with :math:`P` usually representing a distribution
over data and :math:`Q` often being a prior or an approximation of :math:`P`. Note that the KL divergence is
not symmetric, i.e. :math:`D_{KL}(P||Q) \neq D_{KL}(Q||P)`.

Args:
p: data distribution with shape ``[N, d]``
q: prior or approximate distribution with shape ``[N, d]``
log_prob: bool indicating if the input is log-probabilities or probabilities. If given as probabilities,
will normalize to make sure the distributions sum to 1
reduction:
Determines how to reduce over the ``N``/batch dimension:

- ``'mean'`` [default]: Averages score across samples
- ``'sum'``: Sum score across samples
- ``'none'`` or ``None``: Returns score per sample

Raises:
TypeError:
If ``log_prob`` is not a ``bool``
ValueError:
If ``reduction`` is not one of ``'mean'``, ``'sum'``, ``'none'`` or ``None``

.. note::
Half precision is only supported on GPU for this metric

Example:
>>> import torch
>>> from torchmetrics import KLDivergence
>>> p = torch.tensor([[0.36, 0.48, 0.16]])
>>> q = torch.tensor([[1/3, 1/3, 1/3]])
>>> kldivergence = KLDivergence()
>>> kldivergence(p, q)
tensor(0.0853)
"""

def __init__(
self,
log_prob: bool = False,
reduction: Optional[str] = 'mean',
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable] = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
if not isinstance(log_prob, bool):
raise TypeError(f'Expected argument `log_prob` to be bool but got {log_prob}')
self.log_prob = log_prob

allowed_reduction = ['mean', 'sum', 'none', None]
if reduction not in allowed_reduction:
raise ValueError(f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}")
self.reduction = reduction

if self.reduction in ['mean', 'sum']:
self.add_state('measures', torch.zeros(1), dist_reduce_fx='sum')
else:
self.add_state('measures', [], dist_reduce_fx='cat')
self.add_state('total', torch.zeros(1), dist_reduce_fx='sum')

def update(self, p: Tensor, q: Tensor) -> None: # type: ignore
measures, total = _kld_update(p, q, self.log_prob)
if self.reduction is None or self.reduction == 'none':
self.measures.append(measures)
else:
self.measures += measures.sum()
self.total += total

def compute(self) -> Tensor:
measures = dim_zero_cat(self.measures) if self.reduction is None or self.reduction == 'none' else self.measures
return _kld_compute(measures, self.total, self.reduction)

@property
def is_differentiable(self) -> bool:
return True
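
A brief usage sketch for the module above (illustrative only; it assumes the constructor arguments shown in this diff): with the default ``reduction='mean'``, ``measures`` accumulates a running sum of per-sample divergences and ``total`` a sample count, so updating over several batches matches computing over all samples at once.

import torch

from torchmetrics import KLDivergence

kl = KLDivergence(log_prob=False, reduction='mean')

for _ in range(4):
    p = torch.rand(16, 10)  # unnormalized probabilities; normalized inside the update
    q = torch.rand(16, 10)
    kl.update(p, q)

print(kl.compute())  # scalar tensor: mean KL divergence over all 64 samples
kl.reset()           # clear accumulated state, e.g. at the end of an epoch
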
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -25,6 +25,7 @@
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.hinge import hinge # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.kldivergence import kldivergence # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401
1 change: 1 addition & 0 deletions torchmetrics/functional/classification/__init__.py
@@ -22,6 +22,7 @@
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.hinge import hinge # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.kldivergence import kldivergence # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401
82 changes: 82 additions & 0 deletions torchmetrics/functional/classification/kldivergence.py
@@ -0,0 +1,82 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.data import METRIC_EPS


def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> Tuple[Tensor, int]:
_check_same_shape(p, q)
if p.ndim != 2 or q.ndim != 2:
raise ValueError(f"Expected both p and q distribution to be 2D but got {p.ndim} and {q.ndim} respectively")

total = p.shape[0]
if log_prob:
measures = torch.sum(p.exp() * (p - q), axis=-1)
else:
p = p / p.sum(axis=-1, keepdim=True)
q = q / q.sum(axis=-1, keepdim=True)
q = torch.clamp(q, METRIC_EPS)
measures = torch.sum(p * torch.log(p / q), axis=-1)

return measures, total


def _kld_compute(measures: Tensor, total: Tensor, reduction: Optional[str] = 'mean') -> Tensor:
if reduction == 'sum':
return measures.sum()
elif reduction == 'mean':
return measures.sum() / total
elif reduction is None or reduction == 'none':
return measures
return measures / total


def kldivergence(p: Tensor, q: Tensor, log_prob: bool = False, reduction: Optional[str] = 'mean') -> Tensor:
r"""Computes the `KL divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_:

.. math::
D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}

where :math:`P` and :math:`Q` are probability distributions, with :math:`P` usually representing a distribution
over data and :math:`Q` often being a prior or an approximation of :math:`P`. Note that the KL divergence is
not symmetric, i.e. :math:`D_{KL}(P||Q) \neq D_{KL}(Q||P)`.

Args:
p: data distribution with shape ``[N, d]``
q: prior or approximate distribution with shape ``[N, d]``
log_prob: bool indicating if the input is log-probabilities or probabilities. If given as probabilities,
will normalize to make sure the distributions sum to 1
reduction:
Determines how to reduce over the ``N``/batch dimension:

- ``'mean'`` [default]: Averages score across samples
- ``'sum'``: Sum score across samples
- ``'none'`` or ``None``: Returns score per sample

Example:
>>> import torch
>>> from torchmetrics.functional import kldivergence
>>> p = torch.tensor([[0.36, 0.48, 0.16]])
>>> q = torch.tensor([[1/3, 1/3, 1/3]])
>>> kldivergence(p, q)
tensor(0.0853)
"""
measures, total = _kld_update(p, q, log_prob)
return _kld_compute(measures, total, reduction)
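
For reference, a small standalone check (illustrative, not part of this diff) that the functional metric agrees with ``scipy.stats.entropy``, mirroring the ``_sk_metric`` helper in the tests:

import numpy as np
import torch
from scipy.stats import entropy

from torchmetrics.functional import kldivergence

p = torch.rand(8, 5)
q = torch.rand(8, 5)

tm_val = kldivergence(p, q, reduction='mean')          # p and q are normalized internally
sk_val = entropy(p.numpy(), q.numpy(), axis=1).mean()  # scipy normalizes pk and qk as well

assert np.allclose(tm_val.numpy(), sk_val, atol=1e-6)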