Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added MinkowskiDistance support #1362

Merged
merged 66 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
3c01e16
Added MinkowskiDistance support
clueless-skywatcher Nov 26, 2022
b98bb8d
Added class docstring
clueless-skywatcher Nov 26, 2022
c7f4993
Added function docstring; Raw-stringified class docstring
clueless-skywatcher Nov 26, 2022
28fac1b
Apply suggestions from code review
Borda Nov 28, 2022
6a9b4ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 28, 2022
6a75cdf
Addressed comments; Added new testcase for unequal size tensors; Adde…
clueless-skywatcher Nov 28, 2022
25436a9
Merge branch 'master' into master
Borda Dec 7, 2022
3300ee5
Merge branch 'master' into master
SkafteNicki Dec 9, 2022
ad0dc55
Merge branch 'master' into master
Borda Dec 14, 2022
1dcaa3d
Merge branch 'master' into master
Borda Dec 23, 2022
c8ebe00
changelog
SkafteNicki Jan 31, 2023
e525688
Merge branch 'master' into master
SkafteNicki Jan 31, 2023
505c76b
add doc page
SkafteNicki Jan 31, 2023
c947509
fix tests and impl
SkafteNicki Jan 31, 2023
cd6c766
Merge branch 'master' of https://github.com/clueless-skywatcher/metri…
SkafteNicki Jan 31, 2023
c242494
Merge branch 'master' into master
SkafteNicki Jan 31, 2023
62d95d6
fix functional doctest
SkafteNicki Jan 31, 2023
52b665e
fix tests
SkafteNicki Feb 1, 2023
1781ec1
add pairwise version
SkafteNicki Feb 1, 2023
0c50ebd
changelog
SkafteNicki Feb 1, 2023
7cdecd5
Merge branch 'master' into master
SkafteNicki Feb 1, 2023
f29a214
fix doctest
SkafteNicki Feb 1, 2023
51ffab4
Merge branch 'master' of https://github.com/clueless-skywatcher/metri…
SkafteNicki Feb 1, 2023
d833eaf
Merge branch 'master' into master
SkafteNicki Feb 3, 2023
aaa580e
fix old cpu half tests
SkafteNicki Feb 3, 2023
08b9595
Merge branch 'master' into master
SkafteNicki Feb 3, 2023
250c9a7
fix tests
SkafteNicki Feb 3, 2023
a24fdae
add link + update condition
SkafteNicki Feb 3, 2023
307cf42
Merge branch 'master' into master
Borda Feb 6, 2023
237c0ed
Merge branch 'master' into master
mergify[bot] Feb 6, 2023
c8b6892
Merge branch 'master' into master
mergify[bot] Feb 7, 2023
bb49eee
Merge branch 'master' into master
mergify[bot] Feb 7, 2023
3bc0d4c
Merge branch 'master' into master
mergify[bot] Feb 7, 2023
0775df6
Merge branch 'master' into master
Borda Feb 18, 2023
9ba5f21
Merge branch 'master' into master
mergify[bot] Feb 20, 2023
e7dd814
Merge branch 'master' into master
Borda Feb 20, 2023
1afcfcb
Merge branch 'master' into master
mergify[bot] Feb 20, 2023
64c3639
Merge branch 'master' into master
mergify[bot] Feb 20, 2023
89b9ba3
Merge branch 'master' into master
mergify[bot] Feb 20, 2023
6732e8e
Merge branch 'master' into master
Borda Feb 21, 2023
1986ec8
exponent
Borda Feb 21, 2023
478b844
rev
Borda Feb 21, 2023
b73ccc5
tests
Borda Feb 21, 2023
cbe535f
Merge branch 'master' into master
Borda Feb 22, 2023
5ab9a33
Merge branch 'master' into master
Borda Feb 22, 2023
ec17fbe
Merge branch 'master' into master
mergify[bot] Feb 22, 2023
796acd0
Merge branch 'master' into master
mergify[bot] Feb 22, 2023
ec9c9f6
Merge branch 'master' into master
mergify[bot] Feb 23, 2023
5e9a279
Merge branch 'master' into master
mergify[bot] Feb 23, 2023
c6cea6c
Merge branch 'master' into master
mergify[bot] Feb 23, 2023
e221cbf
Merge branch 'master' into master
mergify[bot] Feb 23, 2023
8e05f98
Merge branch 'master' into master
mergify[bot] Feb 23, 2023
b820684
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
c1fa4c6
fix tests
SkafteNicki Feb 24, 2023
4a6d226
Merge branch 'master' into master
SkafteNicki Feb 24, 2023
7944620
Update tests/unittests/pairwise/test_pairwise_distance.py
SkafteNicki Feb 24, 2023
c206baa
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
49f8438
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
2583252
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
eb9af2d
ruff: first line split + imperative mood (#1548
SkafteNicki Feb 24, 2023
e027831
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
53566a9
ruff: first line split + imperative mood (#1548)
SkafteNicki Feb 24, 2023
d02807a
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
12b3a46
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
764dcb2
fix broken imports
SkafteNicki Feb 24, 2023
fd9053d
ruff
SkafteNicki Feb 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419))


- Added `MinkowskiDistance` to regression package ([#1362](https://github.com/Lightning-AI/metrics/pull/1362))


- Added `pairwise_minkowski_distance` to pairwise package ([#1362](https://github.com/Lightning-AI/metrics/pull/1362))


- Added new detection metric `PanopticQuality` ([#929](https://github.com/PyTorchLightning/metrics/pull/929))


Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,4 @@
.. _Multilabel coverage error: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
.. _Panoptic Quality: https://arxiv.org/abs/1801.00868
.. _torchmetrics mAP example: https://github.com/Lightning-AI/metrics/blob/master/examples/detection_map.py
.. _Minkowski Distance: https://en.wikipedia.org/wiki/Minkowski_distance
14 changes: 14 additions & 0 deletions docs/source/pairwise/minkowski_distance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
.. customcarditem::
:header: Pairwise Minkowski Distance
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/translation.svg
:tags: Pairwise

##################
Minkowski Distance
##################

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.pairwise_minkowski_distance
:noindex:
23 changes: 23 additions & 0 deletions docs/source/regression/minkowski_distance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Minkowski Distance
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Regression

.. include:: ../links.rst

##################
Minkowski Distance
##################

Module Interface
________________

.. autoclass:: torchmetrics.MinkowskiDistance
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.minkowski_distance
:noindex:
1 change: 1 addition & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
MeanAbsolutePercentageError,
MeanSquaredError,
MeanSquaredLogError,
MinkowskiDistance,
PearsonCorrCoef,
R2Score,
SpearmanCorrCoef,
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity
from torchmetrics.functional.pairwise.manhattan import pairwise_manhattan_distance
from torchmetrics.functional.pairwise.minkowski import pairwise_minkowski_distance
from torchmetrics.functional.regression.concordance import concordance_corrcoef
from torchmetrics.functional.regression.cosine_similarity import cosine_similarity
from torchmetrics.functional.regression.explained_variance import explained_variance
Expand All @@ -64,6 +65,7 @@
from torchmetrics.functional.regression.log_mse import mean_squared_log_error
from torchmetrics.functional.regression.mae import mean_absolute_error
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error
from torchmetrics.functional.regression.minkowski import minkowski_distance
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
from torchmetrics.functional.regression.mse import mean_squared_error
from torchmetrics.functional.regression.pearson import pearson_corrcoef
from torchmetrics.functional.regression.r2 import r2_score
Expand Down Expand Up @@ -134,11 +136,13 @@
"mean_absolute_percentage_error",
"mean_squared_error",
"mean_squared_log_error",
"minkowski_distance",
"multiscale_structural_similarity_index_measure",
"pairwise_cosine_similarity",
"pairwise_euclidean_distance",
"pairwise_linear_similarity",
"pairwise_manhattan_distance",
"pairwise_minkowski_distance",
"panoptic_quality",
"pearson_corrcoef",
"pearsons_contingency_coefficient",
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/functional/pairwise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance # noqa: F401
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity # noqa: F401
from torchmetrics.functional.pairwise.manhattan import pairwise_manhattan_distance # noqa: F401
from torchmetrics.functional.pairwise.minkowski import pairwise_minkowski_distance # noqa: F401
91 changes: 91 additions & 0 deletions src/torchmetrics/functional/pairwise/minkowski.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix
from torchmetrics.utilities.exceptions import TorchMetricsUserError


def _pairwise_minkowski_distance_update(
x: Tensor, y: Optional[Tensor] = None, exponent: Union[int, float] = 2, zero_diagonal: Optional[bool] = None
) -> Tensor:
"""Calculate the pairwise minkowski distance matrix.

Args:
x: tensor of shape ``[N,d]``
y: tensor of shape ``[M,d]``
exponent: int or float larger than 1, exponent to which the difference between preds and target is to be raised
zero_diagonal: determines if the diagonal of the distance matrix should be set to zero
"""
x, y, zero_diagonal = _check_input(x, y, zero_diagonal)
if not (isinstance(exponent, (float, int)) and exponent >= 1):
raise TorchMetricsUserError(f"Argument ``p`` must be a float or int greater than 1, but got {exponent}")
# upcast to float64 to prevent precision issues
_orig_dtype = x.dtype
x = x.to(torch.float64)
y = y.to(torch.float64)
Borda marked this conversation as resolved.
Show resolved Hide resolved
distance = (x.unsqueeze(1) - y.unsqueeze(0)).abs().pow(exponent).sum(-1).pow(1.0 / exponent)
if zero_diagonal:
distance.fill_diagonal_(0)
return distance.to(_orig_dtype)


def pairwise_minkowski_distance(
x: Tensor,
y: Optional[Tensor] = None,
exponent: Union[int, float] = 2,
reduction: Literal["mean", "sum", "none", None] = None,
zero_diagonal: Optional[bool] = None,
) -> Tensor:
r"""Calculate pairwise minkowski distances.

.. math::
d_{minkowski}(x,y,p) = ||x - y||_p = \sqrt[p]{\sum_{d=1}^D (x_d - y_d)^p}

If both :math:`x` and :math:`y` are passed in, the calculation will be performed pairwise between the rows of
:math:`x` and :math:`y`. If only :math:`x` is passed in, the calculation will be performed between the rows
of :math:`x`.

Args:
x: Tensor with shape ``[N, d]``
y: Tensor with shape ``[M, d]``, optional
exponent: int or float larger than 1, exponent to which the difference between preds and target is to be raised
reduction: reduction to apply along the last dimension. Choose between `'mean'`, `'sum'`
(applied along column dimension) or `'none'`, `None` for no reduction
zero_diagonal: if the diagonal of the distance matrix should be set to 0. If only `x` is given
this defaults to `True` else if `y` is also given it defaults to `False`

Returns:
A ``[N,N]`` matrix of distances if only ``x`` is given, else a ``[N,M]`` matrix

Example:
>>> import torch
>>> from torchmetrics.functional import pairwise_minkowski_distance
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_minkowski_distance(x, y, exponent=4)
tensor([[3.0092, 2.0000],
[5.0317, 4.0039],
[8.1222, 7.0583]])
>>> pairwise_minkowski_distance(x, exponent=4)
tensor([[0.0000, 2.0305, 5.1547],
[2.0305, 0.0000, 3.1383],
[5.1547, 3.1383, 0.0000]])
"""
distance = _pairwise_minkowski_distance_update(x, y, exponent, zero_diagonal)
return _reduce_distance_matrix(distance, reduction)
1 change: 1 addition & 0 deletions src/torchmetrics/functional/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.functional.regression.log_mse import mean_squared_log_error # noqa: F401
from torchmetrics.functional.regression.mae import mean_absolute_error # noqa: F401
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error # noqa: F401
from torchmetrics.functional.regression.minkowski import minkowski_distance # noqa: F401
from torchmetrics.functional.regression.mse import mean_squared_error # noqa: F401
from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401
from torchmetrics.functional.regression.r2 import r2_score # noqa: F401
Expand Down
83 changes: 83 additions & 0 deletions src/torchmetrics/functional/regression/minkowski.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.exceptions import TorchMetricsUserError


def _minkowski_distance_update(preds: Tensor, targets: Tensor, p: float) -> Tensor:
"""Update and return variables required to compute Minkowski distance.

Checks for same shape of input tensors.

Args:
preds: Predicted tensor
targets: Ground truth tensor
p: Non-negative number acting as the p to the errors
"""
_check_same_shape(preds, targets)

if not (isinstance(p, (float, int)) and p >= 1):
raise TorchMetricsUserError(f"Argument ``p`` must be a float or int greater than 1, but got {p}")

difference = torch.abs(preds - targets)
mink_dist_sum = torch.sum(torch.pow(difference, p))

return mink_dist_sum


def _minkowski_distance_compute(distance: Tensor, p: float) -> Tensor:
"""Compute Minkowski Distance.

Args:
distance: Sum of the p-th powers of errors over all observations
p: The non-negative numeric power the errors are to be raised to

Example:
>>> preds = torch.tensor([0., 1, 2, 3])
>>> target = torch.tensor([0., 2, 3, 1])
>>> distance_p_sum = _minkowski_distance_update(preds, target, 5)
>>> _minkowski_distance_compute(distance_p_sum, 5)
tensor(2.0244)
"""
return torch.pow(distance, 1.0 / p)


def minkowski_distance(preds: Tensor, targets: Tensor, p: float) -> Tensor:
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
r"""Compute the `Minkowski distance`_.

.. math:: d_{\text{Minkowski}} = \\sum_{i}^N (| y_i - \\hat{y_i} |^p)^\frac{1}{p}

This metric can be seen as generalized version of the standard euclidean distance which corresponds to minkowski
distance with p=2.

Args:
preds: estimated labels of type Tensor
targets: ground truth labels of type Tensor
p: int or float larger than 1, exponent to which the difference between preds and target is to be raised

Return:
Tensor with the Minkowski distance

Example:
>>> from torchmetrics.functional import minkowski_distance
>>> x = torch.tensor([1.0, 2.8, 3.5, 4.5])
>>> y = torch.tensor([6.1, 2.11, 3.1, 5.6])
>>> minkowski_distance(x, y, p=3)
tensor(5.1220)
"""
minkowski_dist_sum = _minkowski_distance_update(preds, targets, p)
return _minkowski_distance_compute(minkowski_dist_sum, p)
1 change: 1 addition & 0 deletions src/torchmetrics/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.regression.log_mse import MeanSquaredLogError # noqa: F401
from torchmetrics.regression.mae import MeanAbsoluteError # noqa: F401
from torchmetrics.regression.mape import MeanAbsolutePercentageError # noqa: F401
from torchmetrics.regression.minkowski import MinkowskiDistance # noqa: F401
from torchmetrics.regression.mse import MeanSquaredError # noqa: F401
from torchmetrics.regression.pearson import PearsonCorrCoef # noqa: F401
from torchmetrics.regression.r2 import R2Score # noqa: F401
Expand Down
70 changes: 70 additions & 0 deletions src/torchmetrics/regression/minkowski.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

from torch import Tensor, tensor

from torchmetrics.functional.regression.minkowski import _minkowski_distance_compute, _minkowski_distance_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.exceptions import TorchMetricsUserError


class MinkowskiDistance(Metric):
r"""Compute `Minkowski Distance`_.

.. math:: d_{\text{Minkowski}} = \\sum_{i}^N (| y_i - \\hat{y_i} |^p)^\frac{1}{p}

This metric can be seen as generalized version of the standard euclidean distance which corresponds to minkowski
distance with p=2.

where
:math:`y` is a tensor of target values,
:math:`\\hat{y}` is a tensor of predictions,
:math: `\\p` is a non-negative integer or floating-point number

Args:
p: int or float larger than 1, exponent to which the difference between preds and target is to be raised
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> from torchmetrics import MinkowskiDistance
>>> target = tensor([1.0, 2.8, 3.5, 4.5])
>>> preds = tensor([6.1, 2.11, 3.1, 5.6])
>>> minkowski_distance = MinkowskiDistance(3)
>>> minkowski_distance(preds, target)
tensor(5.1220)
"""

is_differentiable: Optional[bool] = True
Borda marked this conversation as resolved.
Show resolved Hide resolved
higher_is_better: Optional[bool] = False
full_state_update: Optional[bool] = False
minkowski_dist_sum: Tensor

def __init__(self, p: float, **kwargs: Any) -> None:
super().__init__(**kwargs)
if not (isinstance(p, (float, int)) and p >= 1):
raise TorchMetricsUserError(f"Argument ``p`` must be a float or int greater than 1, but got {p}")

self.p = p
self.add_state("minkowski_dist_sum", default=tensor(0.0), dist_reduce_fx="sum")

def update(self, preds: Tensor, targets: Tensor) -> None:
"""Update state with predictions and targets."""
minkowski_dist_sum = _minkowski_distance_update(preds, targets, self.p)
self.minkowski_dist_sum += minkowski_dist_sum

def compute(self) -> Tensor:
"""Compute metric."""
return _minkowski_distance_compute(self.minkowski_dist_sum, self.p)
Loading