Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added MinkowskiDistance support #1362

Merged
merged 66 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
3c01e16
Added MinkowskiDistance support
clueless-skywatcher Nov 26, 2022
b98bb8d
Added class docstring
clueless-skywatcher Nov 26, 2022
c7f4993
Added function docstring; Raw-stringified class docstring
clueless-skywatcher Nov 26, 2022
28fac1b
Apply suggestions from code review
Borda Nov 28, 2022
6a9b4ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 28, 2022
6a75cdf
Addressed comments; Added new testcase for unequal size tensors; Adde…
clueless-skywatcher Nov 28, 2022
25436a9
Merge branch 'master' into master
Borda Dec 7, 2022
3300ee5
Merge branch 'master' into master
SkafteNicki Dec 9, 2022
ad0dc55
Merge branch 'master' into master
Borda Dec 14, 2022
1dcaa3d
Merge branch 'master' into master
Borda Dec 23, 2022
c8ebe00
changelog
SkafteNicki Jan 31, 2023
e525688
Merge branch 'master' into master
SkafteNicki Jan 31, 2023
505c76b
add doc page
SkafteNicki Jan 31, 2023
c947509
fix tests and impl
SkafteNicki Jan 31, 2023
cd6c766
Merge branch 'master' of https://github.com/clueless-skywatcher/metri…
SkafteNicki Jan 31, 2023
c242494
Merge branch 'master' into master
SkafteNicki Jan 31, 2023
62d95d6
fix functional doctest
SkafteNicki Jan 31, 2023
52b665e
fix tests
SkafteNicki Feb 1, 2023
1781ec1
add pairwise version
SkafteNicki Feb 1, 2023
0c50ebd
changelog
SkafteNicki Feb 1, 2023
7cdecd5
Merge branch 'master' into master
SkafteNicki Feb 1, 2023
f29a214
fix doctest
SkafteNicki Feb 1, 2023
51ffab4
Merge branch 'master' of https://github.com/clueless-skywatcher/metri…
SkafteNicki Feb 1, 2023
d833eaf
Merge branch 'master' into master
SkafteNicki Feb 3, 2023
aaa580e
fix old cpu half tests
SkafteNicki Feb 3, 2023
08b9595
Merge branch 'master' into master
SkafteNicki Feb 3, 2023
250c9a7
fix tests
SkafteNicki Feb 3, 2023
a24fdae
add link + update condition
SkafteNicki Feb 3, 2023
307cf42
Merge branch 'master' into master
Borda Feb 6, 2023
237c0ed
Merge branch 'master' into master
mergify[bot] Feb 6, 2023
c8b6892
Merge branch 'master' into master
mergify[bot] Feb 7, 2023
bb49eee
Merge branch 'master' into master
mergify[bot] Feb 7, 2023
3bc0d4c
Merge branch 'master' into master
mergify[bot] Feb 7, 2023
0775df6
Merge branch 'master' into master
Borda Feb 18, 2023
9ba5f21
Merge branch 'master' into master
mergify[bot] Feb 20, 2023
e7dd814
Merge branch 'master' into master
Borda Feb 20, 2023
1afcfcb
Merge branch 'master' into master
mergify[bot] Feb 20, 2023
64c3639
Merge branch 'master' into master
mergify[bot] Feb 20, 2023
89b9ba3
Merge branch 'master' into master
mergify[bot] Feb 20, 2023
6732e8e
Merge branch 'master' into master
Borda Feb 21, 2023
1986ec8
exponent
Borda Feb 21, 2023
478b844
rev
Borda Feb 21, 2023
b73ccc5
tests
Borda Feb 21, 2023
cbe535f
Merge branch 'master' into master
Borda Feb 22, 2023
5ab9a33
Merge branch 'master' into master
Borda Feb 22, 2023
ec17fbe
Merge branch 'master' into master
mergify[bot] Feb 22, 2023
796acd0
Merge branch 'master' into master
mergify[bot] Feb 22, 2023
ec9c9f6
Merge branch 'master' into master
mergify[bot] Feb 23, 2023
5e9a279
Merge branch 'master' into master
mergify[bot] Feb 23, 2023
c6cea6c
Merge branch 'master' into master
mergify[bot] Feb 23, 2023
e221cbf
Merge branch 'master' into master
mergify[bot] Feb 23, 2023
8e05f98
Merge branch 'master' into master
mergify[bot] Feb 23, 2023
b820684
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
c1fa4c6
fix tests
SkafteNicki Feb 24, 2023
4a6d226
Merge branch 'master' into master
SkafteNicki Feb 24, 2023
7944620
Update tests/unittests/pairwise/test_pairwise_distance.py
SkafteNicki Feb 24, 2023
c206baa
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
49f8438
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
2583252
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
eb9af2d
ruff: first line split + imperative mood (#1548
SkafteNicki Feb 24, 2023
e027831
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
53566a9
ruff: first line split + imperative mood (#1548)
SkafteNicki Feb 24, 2023
d02807a
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
12b3a46
Merge branch 'master' into master
mergify[bot] Feb 24, 2023
764dcb2
fix broken imports
SkafteNicki Feb 24, 2023
fd9053d
ruff
SkafteNicki Feb 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
MeanAbsolutePercentageError,
MeanSquaredError,
MeanSquaredLogError,
MinkowskiDistance,
PearsonCorrCoef,
R2Score,
SpearmanCorrCoef,
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from torchmetrics.functional.regression.log_mse import mean_squared_log_error
from torchmetrics.functional.regression.mae import mean_absolute_error
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error
from torchmetrics.functional.regression.minkowski import minkowski_distance
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
from torchmetrics.functional.regression.mse import mean_squared_error
from torchmetrics.functional.regression.pearson import pearson_corrcoef
from torchmetrics.functional.regression.r2 import r2_score
Expand Down Expand Up @@ -133,6 +134,7 @@
"mean_absolute_percentage_error",
"mean_squared_error",
"mean_squared_log_error",
"minkowski_distance",
"multiscale_structural_similarity_index_measure",
"pairwise_cosine_similarity",
"pairwise_euclidean_distance",
Expand Down
50 changes: 50 additions & 0 deletions src/torchmetrics/functional/regression/minkowski.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Optional
Borda marked this conversation as resolved.
Show resolved Hide resolved

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.exceptions import TorchMetricsUserError


def _minkowski_distance_update(preds: Tensor, targets: Tensor, p: float) -> Tensor:
"""Updates and returns variables required to compute Minkowski distance.

Checks for same shape of input tensors.

Args:
preds: Predicted tensor
target: Ground truth tensor
p: Non-negative number acting as the exponent to the errors
Borda marked this conversation as resolved.
Show resolved Hide resolved
"""
_check_same_shape(preds, targets)

if p < 0:
Borda marked this conversation as resolved.
Show resolved Hide resolved
raise TorchMetricsUserError("p value must be greater than 0")
Borda marked this conversation as resolved.
Show resolved Hide resolved

difference = torch.abs(preds - targets)
mink_dist_sum = torch.sum(torch.pow(difference, p))

return mink_dist_sum


def _minkowski_distance_compute(distance: Tensor, p: float) -> Tensor:
"""Computes Minkowski Distance.

Args:
distance: Sum of the p-th powers of errors over all observations
p: The non-negative numeric power the errors are to be raised to

Example:
>>> preds = torch.tensor([0., 1, 2, 3])
>>> target = torch.tensor([0., 2, 3, 1])
>>> distance_p_sum = _minkowski_distance_update(preds, target, 5)
>>> _minkowski_distance_compute(distance_p_sum, 5)
tensor(2.0244)
"""
return torch.pow(distance, 1.0 / p)


def minkowski_distance(preds: Tensor, targets: Tensor, p: float) -> Tensor:
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
minkowski_dist_sum = _minkowski_distance_update(preds, targets, p)
return _minkowski_distance_compute(minkowski_dist_sum, p)
1 change: 1 addition & 0 deletions src/torchmetrics/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.regression.log_mse import MeanSquaredLogError # noqa: F401
from torchmetrics.regression.mae import MeanAbsoluteError # noqa: F401
from torchmetrics.regression.mape import MeanAbsolutePercentageError # noqa: F401
from torchmetrics.regression.minkowski import MinkowskiDistance # noqa: F401
from torchmetrics.regression.mse import MeanSquaredError # noqa: F401
from torchmetrics.regression.pearson import PearsonCorrCoef # noqa: F401
from torchmetrics.regression.r2 import R2Score # noqa: F401
Expand Down
63 changes: 63 additions & 0 deletions src/torchmetrics/regression/minkowski.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import torch
from torch import Tensor, tensor

from torchmetrics.functional.regression.minkowski import _minkowski_distance_compute, _minkowski_distance_update
from torchmetrics.metric import Metric


class MinkowskiDistance(Metric):
r"""Computes `Minkowski Distance`

.. math:: d_{\text{Minkowski}} = \\sum_{i}^N (| y_i - \\hat{y_i} |^p)^\frac{1}{p}

where
:math:`y` is a tensor of target values,
:math:`\\hat{y}` is a tensor of predictions,
:math: `\\p` is a non-negative integer or floating-point number

Args:
p: A non-negative number acting as the exponent in the calculation
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> from torchmetrics import MeanSquaredError
Borda marked this conversation as resolved.
Show resolved Hide resolved
>>> target = torch.tensor([1.0, 2.8, 3.5, 4.5])
>>> preds = torch.tensor([6.1, 2.11, 3.1, 5.6])
>>> minkowski_distance = MinkowskiDistance(3)
>>> minkowski_distance(preds, target)
tensor(5.1220)
"""

is_differentiable: Optional[bool] = True
Borda marked this conversation as resolved.
Show resolved Hide resolved
minkowski_dist_sum: Tensor

def __init__(self, p: float, **kwargs: Any) -> None:
super().__init__(**kwargs)

self.add_state("minkowski_dist_sum", default=tensor(0.0))
self.p = p
Borda marked this conversation as resolved.
Show resolved Hide resolved

def update(self, preds: Tensor, targets: Tensor) -> None:
minkowski_dist_sum = _minkowski_distance_update(preds, targets, self.p)

self.minkowski_dist_sum += minkowski_dist_sum

def compute(self) -> Tensor:
return _minkowski_distance_compute(self.minkowski_dist_sum, self.p)
104 changes: 104 additions & 0 deletions tests/unittests/regression/test_minkowski_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from collections import namedtuple
from functools import partial

import numpy as np
import pytest
import torch
from scipy.spatial.distance import minkowski as scipy_minkowski
from sklearn.metrics import mean_squared_error

from torchmetrics.functional import minkowski_distance
from torchmetrics.regression import MinkowskiDistance
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from unittests.helpers import seed_all
from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester

seed_all(42)

num_targets = 5

Input = namedtuple("Input", ["preds", "target"])

_single_target_inputs = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE))

_multi_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets)
)


def root_mean_squared_error(preds, target):
preds, target = preds.numpy(), target.numpy()
return np.sqrt(mean_squared_error(preds - target))


def _single_target_sk_metric(preds, target, sk_fn, metric_args):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

res = sk_fn(sk_preds, sk_target, p=metric_args["p"])

return res


def _multi_target_sk_metric(preds, target, sk_fn, metric_args):
sk_preds = preds.view(-1, num_targets).numpy()
sk_target = target.view(-1, num_targets).numpy

res = sk_fn(sk_preds, sk_target, p=metric_args["p"])

return res


@pytest.mark.parametrize(
"preds, target, sk_metric",
[
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric),
],
)
@pytest.mark.parametrize(
"metric_class, metric_functional, sk_fn, metric_args",
[
(MinkowskiDistance, minkowski_distance, scipy_minkowski, {"p": 1}),
(MinkowskiDistance, minkowski_distance, scipy_minkowski, {"p": 2}),
(MinkowskiDistance, minkowski_distance, scipy_minkowski, {"p": 3}),
(MinkowskiDistance, minkowski_distance, scipy_minkowski, {"p": 4}),
(MinkowskiDistance, minkowski_distance, scipy_minkowski, {"p": 5}),
(MinkowskiDistance, minkowski_distance, scipy_minkowski, {"p": 0.5}),
(MinkowskiDistance, minkowski_distance, scipy_minkowski, {"p": 1.5}),
(MinkowskiDistance, minkowski_distance, scipy_minkowski, {"p": -1.25}),
(MinkowskiDistance, minkowski_distance, scipy_minkowski, {"p": -0.5}),
(MinkowskiDistance, minkowski_distance, scipy_minkowski, {"p": -8}),
],
)
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
class TestMinkowskiDistance(MetricTester):
@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_minkowski_distance_class(
self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args, ddp, dist_sync_on_step
):
if metric_args["p"] < 0:
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
pytest.xfail("p-value must not be less than 0")

self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=metric_class,
sk_metric=partial(sk_metric, sk_fn=sk_fn, metric_args=metric_args),
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
)

def test_minkowski_distance_functional(
self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args
):
if metric_args["p"] < 0:
pytest.xfail("p-value must not be less than 0")

self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=metric_functional,
sk_metric=partial(sk_metric, sk_fn=sk_fn, metric_args=metric_args),
metric_args=metric_args,
)