-
Notifications
You must be signed in to change notification settings - Fork 402
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added new image metric - SAM * Added new image metric - SAM * revise commit indent * revise library * revise docs string * update docs ang changelog * revise conflict with master * revise compute correctly spectral angle, change functional name * Apply suggestions from code review * wording * add citation * fix move dim with lower version of pytorch * fix old PT * fix old PT torch clip * Update torchmetrics/image/sam.py * Update torchmetrics/functional/image/sam.py * Update torchmetrics/image/sam.py * Update torchmetrics/image/sam.py * move reference test function, change docs * resolve conflict with master * fix changelog * fix docs * fix mypy Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
- Loading branch information
1 parent
865a08f
commit f144425
Showing
11 changed files
with
353 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from collections import namedtuple | ||
from functools import partial | ||
|
||
import pytest | ||
import torch | ||
|
||
from tests.helpers import seed_all | ||
from tests.helpers.reference_metrics import _sk_sam | ||
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester | ||
from torchmetrics.functional.image.sam import spectral_angle_mapper | ||
from torchmetrics.image.sam import SpectralAngleMapper | ||
|
||
seed_all(42) | ||
|
||
Input = namedtuple("Input", ["preds", "target"]) | ||
|
||
_inputs = [] | ||
for size, channel, dtype in [ | ||
(12, 3, torch.float), | ||
(13, 3, torch.float32), | ||
(14, 3, torch.double), | ||
(15, 3, torch.float64), | ||
]: | ||
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) | ||
target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) | ||
_inputs.append(Input(preds=preds, target=target)) | ||
|
||
|
||
@pytest.mark.parametrize("reduction", ["sum", "elementwise_mean"]) | ||
@pytest.mark.parametrize( | ||
"preds, target", | ||
[(i.preds, i.target) for i in _inputs], | ||
) | ||
class TestSpectralAngleMapper(MetricTester): | ||
@pytest.mark.parametrize("ddp", [True, False]) | ||
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) | ||
def test_sam(self, reduction, preds, target, ddp, dist_sync_on_step): | ||
self.run_class_metric_test( | ||
ddp, | ||
preds, | ||
target, | ||
SpectralAngleMapper, | ||
partial(_sk_sam, reduction=reduction), | ||
dist_sync_on_step, | ||
metric_args=dict(reduction=reduction), | ||
) | ||
|
||
def test_sam_functional(self, reduction, preds, target): | ||
self.run_functional_metric_test( | ||
preds, | ||
target, | ||
spectral_angle_mapper, | ||
partial(_sk_sam, reduction=reduction), | ||
metric_args=dict(reduction=reduction), | ||
) | ||
|
||
# SAM half + cpu does not work due to missing support in torch.log | ||
@pytest.mark.xfail(reason="SAM metric does not support cpu + half precision") | ||
def test_sam_half_cpu(self, reduction, preds, target): | ||
self.run_precision_test_cpu( | ||
preds, | ||
target, | ||
SpectralAngleMapper, | ||
spectral_angle_mapper, | ||
) | ||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") | ||
def test_sam_half_gpu(self, reduction, preds, target): | ||
self.run_precision_test_gpu(preds, target, SpectralAngleMapper, spectral_angle_mapper) | ||
|
||
|
||
def test_error_on_different_shape(metric_class=SpectralAngleMapper): | ||
metric = metric_class() | ||
with pytest.raises(RuntimeError): | ||
metric(torch.randn([1, 3, 16, 16]), torch.randn([1, 1, 16, 16])) | ||
|
||
|
||
def test_error_on_invalid_shape(metric_class=SpectralAngleMapper): | ||
metric = metric_class() | ||
with pytest.raises(ValueError): | ||
metric(torch.randn([3, 16, 16]), torch.randn([3, 16, 16])) | ||
|
||
|
||
def test_error_on_invalid_type(metric_class=SpectralAngleMapper): | ||
metric = metric_class() | ||
with pytest.raises(TypeError): | ||
metric(torch.randn([3, 16, 16]), torch.randn([3, 16, 16], dtype=torch.float64)) | ||
|
||
|
||
def test_error_on_grayscale_image(metric_class=SpectralAngleMapper): | ||
metric = metric_class() | ||
with pytest.raises(ValueError): | ||
metric(torch.randn([16, 1, 16, 16]), torch.randn([16, 1, 16, 16])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Tuple | ||
|
||
import torch | ||
from torch import Tensor | ||
from typing_extensions import Literal | ||
|
||
from torchmetrics.utilities.checks import _check_same_shape | ||
from torchmetrics.utilities.distributed import reduce | ||
|
||
|
||
def _sam_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: | ||
"""Updates and returns variables required to compute Spectral Angle Mapper. Checks for same shape and type of | ||
the input tensors. | ||
Args: | ||
preds: Predicted tensor | ||
target: Ground truth tensor | ||
""" | ||
|
||
if preds.dtype != target.dtype: | ||
raise TypeError( | ||
"Expected `preds` and `target` to have the same data type." | ||
f" Got preds: {preds.dtype} and target: {target.dtype}." | ||
) | ||
_check_same_shape(preds, target) | ||
if len(preds.shape) != 4: | ||
raise ValueError( | ||
"Expected `preds` and `target` to have BxCxHxW shape." | ||
f" Got preds: {preds.shape} and target: {target.shape}." | ||
) | ||
if (preds.shape[1] <= 1) or (target.shape[1] <= 1): | ||
raise ValueError( | ||
"Expected channel dimension of `preds` and `target` to be larger than 1." | ||
f" Got preds: {preds.shape[1]} and target: {target.shape[1]}." | ||
) | ||
return preds, target | ||
|
||
|
||
def _sam_compute( | ||
preds: Tensor, | ||
target: Tensor, | ||
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", | ||
) -> Tensor: | ||
"""Computes Spectral Angle Mapper. | ||
Args: | ||
preds: estimated image | ||
target: ground truth image | ||
reduction: a method to reduce metric score over labels. | ||
- ``'elementwise_mean'``: takes the mean (default) | ||
- ``'sum'``: takes the sum | ||
- ``'none'`` or ``None``: no reduction will be applied | ||
Example: | ||
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)) | ||
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)) | ||
>>> preds, target = _sam_update(preds, target) | ||
>>> _sam_compute(preds, target) | ||
tensor(0.5943) | ||
""" | ||
dot_product = (preds * target).sum(dim=1) | ||
preds_norm = preds.norm(dim=1) | ||
target_norm = target.norm(dim=1) | ||
sam_score = torch.clamp(dot_product / (preds_norm * target_norm), -1, 1).acos() | ||
return reduce(sam_score, reduction) | ||
|
||
|
||
def spectral_angle_mapper( | ||
preds: Tensor, | ||
target: Tensor, | ||
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", | ||
) -> Tensor: | ||
"""Universal Spectral Angle Mapper. | ||
Args: | ||
preds: estimated image | ||
target: ground truth image | ||
reduction: a method to reduce metric score over labels. | ||
- ``'elementwise_mean'``: takes the mean (default) | ||
- ``'sum'``: takes the sum | ||
- ``'none'`` or ``None``: no reduction will be applied | ||
Return: | ||
Tensor with Spectral Angle Mapper score | ||
Raises: | ||
TypeError: | ||
If ``preds`` and ``target`` don't have the same data type. | ||
ValueError: | ||
If ``preds`` and ``target`` don't have ``BxCxHxW shape``. | ||
Example: | ||
>>> from torchmetrics.functional import spectral_angle_mapper | ||
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)) | ||
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)) | ||
>>> spectral_angle_mapper(preds, target) | ||
tensor(0.5943) | ||
References: Roberta H. Yuhas, Alexander F. H. Goetz and Joe W. Boardman, "Discrimination among semi-arid | ||
landscape endmembers using the Spectral Angle Mapper (SAM) algorithm" in PL, Summaries of the Third Annual JPL | ||
Airborne Geoscience Workshop, vol. 1, June 1, 1992. | ||
""" | ||
preds, target = _sam_update(preds, target) | ||
return _sam_compute(preds, target, reduction) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.