Skip to content

Commit

Permalink
Add new metrics: SAM (#885)
Browse files Browse the repository at this point in the history
* Added new image metric - SAM

* Added new image metric - SAM

* revise commit indent

* revise library

* revise docs string

* update docs ang changelog

* revise conflict with master

* revise compute correctly spectral angle, change functional name

* Apply suggestions from code review

* wording

* add citation

* fix move dim with lower version of pytorch

* fix old PT

* fix old PT torch clip

* Update torchmetrics/image/sam.py

* Update torchmetrics/functional/image/sam.py

* Update torchmetrics/image/sam.py

* Update torchmetrics/image/sam.py

* move reference test function, change docs

* resolve conflict with master

* fix changelog

* fix docs

* fix mypy

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
  • Loading branch information
6 people authored Mar 21, 2022
1 parent 865a08f commit f144425
Show file tree
Hide file tree
Showing 11 changed files with 353 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added new image metric `SpectralAngleMapper` ([#885](https://github.com/PyTorchLightning/metrics/pull/885))


- Added `CoverageError` to classification metrics ([#787](https://github.com/PyTorchLightning/metrics/pull/787))

Expand Down
8 changes: 8 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,14 @@ peak_signal_noise_ratio [func]
.. autofunction:: torchmetrics.functional.peak_signal_noise_ratio
:noindex:


spectral_angle_mapper [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.spectral_angle_mapper
:noindex:


universal_image_quality_index [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,12 @@ PeakSignalNoiseRatio
.. autoclass:: torchmetrics.PeakSignalNoiseRatio
:noindex:

SpectralAngleMapper
~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.SpectralAngleMapper
:noindex:

StructuralSimilarityIndexMeasure
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
15 changes: 15 additions & 0 deletions tests/helpers/reference_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics._regression import _check_reg_targets
from sklearn.utils import assert_all_finite, check_consistent_length, column_or_1d

Expand Down Expand Up @@ -198,3 +200,16 @@ def _calibration_error(
loss += np.sum(np.nan_to_num(debias))
loss = np.sqrt(max(loss, 0.0))
return loss


def _sk_sam(preds, target, reduction):
similarity = F.cosine_similarity(preds, target)
sam_score = torch.clamp(similarity, -1, 1).acos()
# reduction
if reduction == "sum":
to_return = torch.sum(sam_score)
elif reduction == "elementwise_mean":
to_return = torch.mean(sam_score)
else:
to_return = sam_score
return to_return
106 changes: 106 additions & 0 deletions tests/image/test_sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch

from tests.helpers import seed_all
from tests.helpers.reference_metrics import _sk_sam
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.image.sam import SpectralAngleMapper

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])

_inputs = []
for size, channel, dtype in [
(12, 3, torch.float),
(13, 3, torch.float32),
(14, 3, torch.double),
(15, 3, torch.float64),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
_inputs.append(Input(preds=preds, target=target))


@pytest.mark.parametrize("reduction", ["sum", "elementwise_mean"])
@pytest.mark.parametrize(
"preds, target",
[(i.preds, i.target) for i in _inputs],
)
class TestSpectralAngleMapper(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_sam(self, reduction, preds, target, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
SpectralAngleMapper,
partial(_sk_sam, reduction=reduction),
dist_sync_on_step,
metric_args=dict(reduction=reduction),
)

def test_sam_functional(self, reduction, preds, target):
self.run_functional_metric_test(
preds,
target,
spectral_angle_mapper,
partial(_sk_sam, reduction=reduction),
metric_args=dict(reduction=reduction),
)

# SAM half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="SAM metric does not support cpu + half precision")
def test_sam_half_cpu(self, reduction, preds, target):
self.run_precision_test_cpu(
preds,
target,
SpectralAngleMapper,
spectral_angle_mapper,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_sam_half_gpu(self, reduction, preds, target):
self.run_precision_test_gpu(preds, target, SpectralAngleMapper, spectral_angle_mapper)


def test_error_on_different_shape(metric_class=SpectralAngleMapper):
metric = metric_class()
with pytest.raises(RuntimeError):
metric(torch.randn([1, 3, 16, 16]), torch.randn([1, 1, 16, 16]))


def test_error_on_invalid_shape(metric_class=SpectralAngleMapper):
metric = metric_class()
with pytest.raises(ValueError):
metric(torch.randn([3, 16, 16]), torch.randn([3, 16, 16]))


def test_error_on_invalid_type(metric_class=SpectralAngleMapper):
metric = metric_class()
with pytest.raises(TypeError):
metric(torch.randn([3, 16, 16]), torch.randn([3, 16, 16], dtype=torch.float64))


def test_error_on_grayscale_image(metric_class=SpectralAngleMapper):
metric = metric_class()
with pytest.raises(ValueError):
metric(torch.randn([16, 1, 16, 16]), torch.randn([16, 1, 16, 16]))
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from torchmetrics.image import ( # noqa: E402
MultiScaleStructuralSimilarityIndexMeasure,
PeakSignalNoiseRatio,
SpectralAngleMapper,
StructuralSimilarityIndexMeasure,
UniversalImageQualityIndex,
)
Expand Down Expand Up @@ -167,6 +168,7 @@
"SignalNoiseRatio",
"SpearmanCorrCoef",
"Specificity",
"SpectralAngleMapper",
"SQuAD",
"StructuralSimilarityIndexMeasure",
"StatScores",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from torchmetrics.functional.classification.stat_scores import stat_scores
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.functional.image.ssim import (
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
Expand Down Expand Up @@ -152,6 +153,7 @@
"symmetric_mean_absolute_percentage_error",
"translation_edit_rate",
"universal_image_quality_index",
"spectral_angle_mapper",
"word_error_rate",
"char_error_rate",
"match_error_rate",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.functional.image.gradients import image_gradients # noqa: F401
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio # noqa: F401
from torchmetrics.functional.image.sam import spectral_angle_mapper # noqa: F401
from torchmetrics.functional.image.ssim import ( # noqa: F401
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
Expand Down
119 changes: 119 additions & 0 deletions torchmetrics/functional/image/sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce


def _sam_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute Spectral Angle Mapper. Checks for same shape and type of
the input tensors.
Args:
preds: Predicted tensor
target: Ground truth tensor
"""

if preds.dtype != target.dtype:
raise TypeError(
"Expected `preds` and `target` to have the same data type."
f" Got preds: {preds.dtype} and target: {target.dtype}."
)
_check_same_shape(preds, target)
if len(preds.shape) != 4:
raise ValueError(
"Expected `preds` and `target` to have BxCxHxW shape."
f" Got preds: {preds.shape} and target: {target.shape}."
)
if (preds.shape[1] <= 1) or (target.shape[1] <= 1):
raise ValueError(
"Expected channel dimension of `preds` and `target` to be larger than 1."
f" Got preds: {preds.shape[1]} and target: {target.shape[1]}."
)
return preds, target


def _sam_compute(
preds: Tensor,
target: Tensor,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Computes Spectral Angle Mapper.
Args:
preds: estimated image
target: ground truth image
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'`` or ``None``: no reduction will be applied
Example:
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> preds, target = _sam_update(preds, target)
>>> _sam_compute(preds, target)
tensor(0.5943)
"""
dot_product = (preds * target).sum(dim=1)
preds_norm = preds.norm(dim=1)
target_norm = target.norm(dim=1)
sam_score = torch.clamp(dot_product / (preds_norm * target_norm), -1, 1).acos()
return reduce(sam_score, reduction)


def spectral_angle_mapper(
preds: Tensor,
target: Tensor,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Universal Spectral Angle Mapper.
Args:
preds: estimated image
target: ground truth image
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'`` or ``None``: no reduction will be applied
Return:
Tensor with Spectral Angle Mapper score
Raises:
TypeError:
If ``preds`` and ``target`` don't have the same data type.
ValueError:
If ``preds`` and ``target`` don't have ``BxCxHxW shape``.
Example:
>>> from torchmetrics.functional import spectral_angle_mapper
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> spectral_angle_mapper(preds, target)
tensor(0.5943)
References: Roberta H. Yuhas, Alexander F. H. Goetz and Joe W. Boardman, "Discrimination among semi-arid
landscape endmembers using the Spectral Angle Mapper (SAM) algorithm" in PL, Summaries of the Third Annual JPL
Airborne Geoscience Workshop, vol. 1, June 1, 1992.
"""
preds, target = _sam_update(preds, target)
return _sam_compute(preds, target, reduction)
1 change: 1 addition & 0 deletions torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.image.psnr import PeakSignalNoiseRatio # noqa: F401
from torchmetrics.image.sam import SpectralAngleMapper # noqa: F401
from torchmetrics.image.ssim import ( # noqa: F401
MultiScaleStructuralSimilarityIndexMeasure,
StructuralSimilarityIndexMeasure,
Expand Down
Loading

0 comments on commit f144425

Please sign in to comment.