Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pairwise subpackage #553

Merged
merged 37 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a000c41
start
SkafteNicki Jul 17, 2021
afedcd6
deprecate
SkafteNicki Jul 17, 2021
38e615d
init
SkafteNicki Jul 17, 2021
23937b5
changes
SkafteNicki Jul 23, 2021
ed17876
Merge branch 'master' into pairwise
SkafteNicki Sep 27, 2021
94484ae
something working
SkafteNicki Sep 27, 2021
914c966
update
SkafteNicki Sep 30, 2021
19ff511
Merge branch 'master' into pairwise
SkafteNicki Sep 30, 2021
7208899
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2021
cb6ccb6
Apply suggestions from code review
Borda Sep 30, 2021
f7c4d50
Merge branch 'master' into pairwise
SkafteNicki Oct 1, 2021
463a6bd
change of variable names
SkafteNicki Oct 1, 2021
3b39627
move check input
SkafteNicki Oct 1, 2021
8168358
16 bit testing
SkafteNicki Oct 1, 2021
5a33d53
move to helper
SkafteNicki Oct 1, 2021
fe56370
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2021
6e4d451
fix flake
SkafteNicki Oct 1, 2021
664e202
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2021
a98aec5
mypy
SkafteNicki Oct 1, 2021
f9ed5b8
Merge branch 'pairwise' of https://github.com/PyTorchLightning/metric…
SkafteNicki Oct 1, 2021
73e04ac
Merge branch 'master' into pairwise
SkafteNicki Oct 7, 2021
c85451a
skip gpu test
SkafteNicki Oct 7, 2021
aacf41b
Merge branch 'pairwise' of https://github.com/PyTorchLightning/metric…
SkafteNicki Oct 7, 2021
78ffd72
fix test
SkafteNicki Oct 8, 2021
8d518d9
update
SkafteNicki Oct 8, 2021
b58ac2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2021
168617b
atol
SkafteNicki Oct 8, 2021
1182aa1
Merge branch 'pairwise' of https://github.com/PyTorchLightning/metric…
SkafteNicki Oct 8, 2021
6093606
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2021
3ab815e
Merge branch 'master' into pairwise
SkafteNicki Oct 11, 2021
1a4c432
changelog
SkafteNicki Oct 11, 2021
b7f699e
Merge branch 'pairwise' of https://github.com/PyTorchLightning/metric…
SkafteNicki Oct 11, 2021
cc2e71d
suggestions
SkafteNicki Oct 11, 2021
49ef598
Merge branch 'master' into pairwise
Borda Oct 13, 2021
413ee37
Merge branch 'master' into pairwise
SkafteNicki Oct 13, 2021
9bc02bb
Apply suggestions from code review
SkafteNicki Oct 14, 2021
51b81b4
Merge branch 'master' into pairwise
Borda Oct 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added simple aggregation metrics: `SumMetric`, `MeanMetric`, `CatMetric`, `MinMetric`, `MaxMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506))


- Added pairwise submodule with metrics ([#553](https://github.com/PyTorchLightning/metrics/pull/553))
- `pairwise_cosine_similarity`
- `pairwise_euclidean_distance`
- `pairwise_linear_similarity`
- `pairwise_manhatten_distance`


- Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))


Expand All @@ -55,6 +62,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Deprecated

- Deprecated `torchmetrics.functional.self_supervised.embedding_similarity` in favour of new pairwise submodule

### Removed

Expand Down
34 changes: 28 additions & 6 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -328,16 +328,38 @@ tweedie_deviance_score [func]
:noindex:


********
Pairwise
********
****************
Pairwise Metrics
****************

embedding_similarity [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~
pairwise_cosine_similarity [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.embedding_similarity
.. autofunction:: torchmetrics.functional.pairwise_cosine_similarity
:noindex:


pairwise_euclidean_distance [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.pairwise_euclidean_distance
:noindex:


pairwise_linear_similarity [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.pairwise_linear_similarity
:noindex:


pairwise_manhatten_distance [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.pairwise_manhatten_distance
:noindex:


*********
Retrieval
*********
Expand Down
25 changes: 18 additions & 7 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def _functional_test(


def _assert_half_support(
metric_module: Metric,
metric_module: Optional[Metric],
metric_functional: Optional[Callable],
preds: Tensor,
target: Tensor,
Expand All @@ -286,8 +286,9 @@ def _assert_half_support(
k: (v[0].half() if v.is_floating_point() else v[0]).to(device) if isinstance(v, Tensor) else v
for k, v in kwargs_update.items()
}
metric_module = metric_module.to(device)
_assert_tensor(metric_module(y_hat, y, **kwargs_update))
if metric_module is not None:
metric_module = metric_module.to(device)
_assert_tensor(metric_module(y_hat, y, **kwargs_update))
if metric_functional is not None:
_assert_tensor(metric_functional(y_hat, y, **kwargs_update))

Expand Down Expand Up @@ -436,7 +437,7 @@ def run_class_metric_test(
def run_precision_test_cpu(
preds: Tensor,
target: Tensor,
metric_module: Metric,
metric_module: Optional[Metric] = None,
metric_functional: Optional[Callable] = None,
metric_args: Optional[dict] = None,
**kwargs_update,
Expand All @@ -453,14 +454,19 @@ def run_precision_test_cpu(
"""
metric_args = metric_args or {}
_assert_half_support(
metric_module(**metric_args), metric_functional, preds, target, device="cpu", **kwargs_update
metric_module(**metric_args) if metric_module is not None else None,
metric_functional,
preds,
target,
device="cpu",
**kwargs_update,
)

@staticmethod
def run_precision_test_gpu(
preds: Tensor,
target: Tensor,
metric_module: Metric,
metric_module: Optional[Metric] = None,
metric_functional: Optional[Callable] = None,
metric_args: Optional[dict] = None,
**kwargs_update,
Expand All @@ -477,7 +483,12 @@ def run_precision_test_gpu(
"""
metric_args = metric_args or {}
_assert_half_support(
metric_module(**metric_args), metric_functional, preds, target, device="cuda", **kwargs_update
metric_module(**metric_args) if metric_module is not None else None,
metric_functional,
preds,
target,
device="cuda",
**kwargs_update,
)

@staticmethod
Expand Down
Empty file added tests/pairwise/__init__.py
Empty file.
121 changes: 121 additions & 0 deletions tests/pairwise/test_pairwise_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, linear_kernel, manhattan_distances

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional import (
pairwise_cosine_similarity,
pairwise_euclidean_distance,
pairwise_linear_similarity,
pairwise_manhatten_distance,
)
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7

seed_all(42)

extra_dim = 5

Input = namedtuple("Input", ["x", "y"])


_inputs1 = Input(
x=torch.rand(NUM_BATCHES, BATCH_SIZE, extra_dim),
y=torch.rand(NUM_BATCHES, BATCH_SIZE, extra_dim),
)


_inputs2 = Input(
x=torch.rand(NUM_BATCHES, BATCH_SIZE, extra_dim),
y=torch.rand(NUM_BATCHES, BATCH_SIZE, extra_dim),
)


def _sk_metric(x, y, sk_fn, reduction):
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
"""comparison function."""
x = x.view(-1, extra_dim).numpy()
y = y.view(-1, extra_dim).numpy()
res = sk_fn(x, y)
if reduction == "sum":
return res.sum(axis=-1)
elif reduction == "mean":
return res.mean(axis=-1)
return res


@pytest.mark.parametrize(
"x, y",
[
(_inputs1.x, _inputs1.y),
(_inputs2.x, _inputs2.y),
],
)
@pytest.mark.parametrize(
"metric_functional, sk_fn",
[
(pairwise_cosine_similarity, cosine_similarity),
(pairwise_euclidean_distance, euclidean_distances),
(pairwise_manhatten_distance, manhattan_distances),
(pairwise_linear_similarity, linear_kernel),
],
)
@pytest.mark.parametrize("reduction", ["sum", "mean", None])
class TestPairwise(MetricTester):
"""test pairwise implementations."""

atol = 1e-4

def test_pairwise_functional(self, x, y, metric_functional, sk_fn, reduction):
"""test functional pairwise implementations."""
self.run_functional_metric_test(
preds=x,
target=y,
metric_functional=metric_functional,
sk_metric=partial(_sk_metric, sk_fn=sk_fn, reduction=reduction),
metric_args={"reduction": reduction},
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_7, reason="half support of core operations on not support before pytorch v1.7"
)
def test_pairwise_half_cpu(self, x, y, metric_functional, sk_fn, reduction):
"""test half precision support on cpu."""
if metric_functional == pairwise_euclidean_distance:
pytest.xfail("pairwise_euclidean_distance metric does not support cpu + half precision")
self.run_precision_test_cpu(x, y, None, metric_functional, metric_args={"reduction": reduction})

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_pairwise_half_gpu(self, x, y, metric_functional, sk_fn, reduction):
"""test half precision support on gpu."""
self.run_precision_test_gpu(x, y, None, metric_functional, metric_args={"reduction": reduction})


@pytest.mark.parametrize(
"metric", [pairwise_cosine_similarity, pairwise_euclidean_distance, pairwise_manhatten_distance]
)
def test_error_on_wrong_shapes(metric):
"""Test errors are raised on wrong input."""
with pytest.raises(ValueError, match="Expected argument `x` to be a 2D tensor .*"):
metric(torch.randn(10))

with pytest.raises(ValueError, match="Expected argument `y` to be a 2D tensor .*"):
metric(torch.randn(10, 5), torch.randn(5, 3))

with pytest.raises(ValueError, match="Expected reduction to be one of .*"):
metric(torch.randn(10, 5), torch.randn(10, 5), reduction=1)
8 changes: 8 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.psnr import psnr
from torchmetrics.functional.image.ssim import ssim
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity
from torchmetrics.functional.pairwise.manhatten import pairwise_manhatten_distance
from torchmetrics.functional.regression.cosine_similarity import cosine_similarity
from torchmetrics.functional.regression.explained_variance import explained_variance
from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error
Expand Down Expand Up @@ -93,6 +97,10 @@
"mean_absolute_percentage_error",
"mean_squared_error",
"mean_squared_log_error",
"pairwise_cosine_similarity",
"pairwise_euclidean_distance",
"pairwise_linear_similarity",
"pairwise_manhatten_distance",
"pearson_corrcoef",
"pesq",
"pit",
Expand Down
17 changes: 17 additions & 0 deletions torchmetrics/functional/pairwise/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity # noqa: F401
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance # noqa: F401
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity # noqa: F401
from torchmetrics.functional.pairwise.manhatten import pairwise_manhatten_distance # noqa: F401
85 changes: 85 additions & 0 deletions torchmetrics/functional/pairwise/cosine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor

from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix


def _pairwise_cosine_similarity_update(
x: Tensor, y: Optional[Tensor] = None, zero_diagonal: Optional[bool] = None
) -> Tensor:
"""Calculates the pairwise cosine similarity matrix.

Args:
x: tensor of shape ``[N,d]``
y: tensor of shape ``[M,d]``
zero_diagonal: determines if the diagonal of the distance matrix should be set to zero
"""
x, y, zero_diagonal = _check_input(x, y, zero_diagonal)

norm = torch.norm(x, p=2, dim=1)
x /= norm.unsqueeze(1)
norm = torch.norm(y, p=2, dim=1)
y /= norm.unsqueeze(1)

distance = x @ y.T
if zero_diagonal:
distance.fill_diagonal_(0)
return distance


def pairwise_cosine_similarity(
x: Tensor, y: Optional[Tensor] = None, reduction: Optional[str] = None, zero_diagonal: Optional[bool] = None
) -> Tensor:
r"""
Calculates pairwise cosine similarity:

.. math::
s_{cos}(x,y) = \frac{<x,y>}{||x|| \cdot ||y||}
= \frac{\sum_{d=1}^D x_d \cdot y_d }{\sqrt{\sum_{d=1}^D x_i^2} \cdot \sqrt{\sum_{d=1}^D x_i^2}}

If both `x` and `y` are passed in, the calculation will be performed pairwise between the rows of `x` and `y`.
If only `x` is passed in, the calculation will be performed between the rows of `x`.

Args:
x: Tensor with shape ``[N, d]``
y: Tensor with shape ``[M, d]``, optional
reduction: reduction to apply along the last dimension. Choose between `'mean'`, `'sum'`
(applied along column dimension) or `'none'`, `None` for no reduction
zero_diagonal: if the diagonal of the distance matrix should be set to 0. If only `x` is given
this defaults to `True` else if `y` is also given it defaults to `False`

Returns:
A ``[N,N]`` matrix of distances if only ``x`` is given, else a ``[N,M]`` matrix

Example:
>>> import torch
>>> from torchmetrics.functional import pairwise_cosine_similarity
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_cosine_similarity(x, y)
tensor([[0.5547, 0.8682],
[0.5145, 0.8437],
[0.5300, 0.8533]])
>>> pairwise_cosine_similarity(x)
tensor([[0.0000, 0.9989, 0.9996],
[0.9989, 0.0000, 0.9998],
[0.9996, 0.9998, 0.0000]])

"""
distance = _pairwise_cosine_similarity_update(x, y, zero_diagonal)
return _reduce_distance_matrix(distance, reduction)
Loading