Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new metrics: SAM #885

Merged
merged 36 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
f0ea2e3
Added new image metric - SAM
vumichien Mar 11, 2022
ab82404
Added new image metric - SAM
vumichien Mar 11, 2022
925caa1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2022
74a3d31
revise commit indent
vumichien Mar 11, 2022
440632c
Merge remote-tracking branch 'origin/feature/sam' into feature/sam
vumichien Mar 11, 2022
02700c2
revise library
vumichien Mar 11, 2022
d4fb716
revise docs string
vumichien Mar 11, 2022
9a6d02a
Merge branch 'master' into feature/sam
Borda Mar 11, 2022
0c95581
update docs ang changelog
vumichien Mar 11, 2022
c066b2e
revise conflict with master
vumichien Mar 12, 2022
9c63be6
revise compute correctly spectral angle, change functional name
vumichien Mar 13, 2022
d5cdf3c
Apply suggestions from code review
Borda Mar 18, 2022
24c7d77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2022
5167ac4
Merge branch 'master' into feature/sam
mergify[bot] Mar 19, 2022
4c2d55a
wording
Borda Mar 19, 2022
9f1e17a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2022
cad085a
Merge branch 'master' into feature/sam
mergify[bot] Mar 19, 2022
ff6026d
add citation
vumichien Mar 20, 2022
197c362
Merge branch 'master' into feature/sam
mergify[bot] Mar 20, 2022
73096f2
fix move dim with lower version of pytorch
vumichien Mar 20, 2022
ac63809
fix old PT
vumichien Mar 20, 2022
b18abb4
fix old PT torch clip
vumichien Mar 20, 2022
608f056
Update torchmetrics/image/sam.py
vumichien Mar 21, 2022
5e5e455
Update torchmetrics/functional/image/sam.py
vumichien Mar 21, 2022
8ee687d
Update torchmetrics/image/sam.py
vumichien Mar 21, 2022
0e6e49b
Update torchmetrics/image/sam.py
vumichien Mar 21, 2022
bc9dc03
move reference test function, change docs
vumichien Mar 21, 2022
743725c
resolve conflict with master
vumichien Mar 21, 2022
60aa750
Merge branch 'master' into feature/sam
Borda Mar 21, 2022
0f4b419
fix
vumichien Mar 21, 2022
768fd04
Merge remote-tracking branch 'origin/feature/sam' into feature/sam
vumichien Mar 21, 2022
d9cc65a
fix changelog
SkafteNicki Mar 21, 2022
bfe12fe
fix docs
vumichien Mar 21, 2022
ac1b9f8
fix mypy
SkafteNicki Mar 21, 2022
9aafd37
Merge branch 'feature/sam' of https://github.com/vumichien/metrics in…
SkafteNicki Mar 21, 2022
bb917de
fix mypy
SkafteNicki Mar 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ universal_image_quality_index [func]
universal_spectral_angle_mapper [func]
vumichien marked this conversation as resolved.
Show resolved Hide resolved
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Borda marked this conversation as resolved.
Show resolved Hide resolved

.. autofunction:: torchmetrics.functional.universal_spectral_angle_mapper
.. autofunction:: torchmetrics.functional.spectral_angle_mapper
:noindex:


Expand Down
20 changes: 14 additions & 6 deletions tests/image/test_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional.image.sam import universal_spectral_angle_mapper
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.image.sam import SpectralAngleMapper

seed_all(42)
Expand All @@ -28,9 +28,9 @@

_inputs = []
for size, channel, dtype in [
(12, 1, torch.float),
(12, 3, torch.float),
(13, 3, torch.float32),
(14, 1, torch.double),
(14, 3, torch.double),
(15, 3, torch.float64),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
Expand All @@ -42,7 +42,9 @@ def _sk_sam(preds, target, reduction):
# reshape to (batch_size, channel, height*width)
B, C, H, W = preds.shape
sk_preds = preds.reshape(B, C, H * W)
sk_preds = torch.movedim(sk_preds, 1, -1)
sk_target = target.reshape(B, C, H * W)
sk_target = torch.movedim(sk_target, 1, -1)
# compute arccos of cosine similarity
dot_product = (sk_preds * sk_target).sum(dim=-1)
preds_norm = sk_preds.norm(dim=-1)
Expand Down Expand Up @@ -82,7 +84,7 @@ def test_sam_functional(self, reduction, preds, target):
self.run_functional_metric_test(
preds,
target,
universal_spectral_angle_mapper,
spectral_angle_mapper,
partial(_sk_sam, reduction=reduction),
metric_args=dict(reduction=reduction),
)
Expand All @@ -94,12 +96,12 @@ def test_sam_half_cpu(self, reduction, preds, target):
preds,
target,
SpectralAngleMapper,
universal_spectral_angle_mapper,
spectral_angle_mapper,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_sam_half_gpu(self, reduction, preds, target):
self.run_precision_test_gpu(preds, target, SpectralAngleMapper, universal_spectral_angle_mapper)
self.run_precision_test_gpu(preds, target, SpectralAngleMapper, spectral_angle_mapper)


def test_error_on_different_shape(metric_class=SpectralAngleMapper):
Expand All @@ -118,3 +120,9 @@ def test_error_on_invalid_type(metric_class=SpectralAngleMapper):
metric = metric_class()
with pytest.raises(TypeError):
metric(torch.randn([3, 16, 16]), torch.randn([3, 16, 16], dtype=torch.float64))


def test_error_on_grayscale_image(metric_class=SpectralAngleMapper):
metric = metric_class()
with pytest.raises(ValueError):
metric(torch.randn([16, 1, 16, 16]), torch.randn([16, 1, 16, 16]))
4 changes: 2 additions & 2 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from torchmetrics.functional.classification.stat_scores import stat_scores
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.sam import universal_spectral_angle_mapper
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.functional.image.ssim import (
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
Expand Down Expand Up @@ -145,7 +145,7 @@
"symmetric_mean_absolute_percentage_error",
"translation_edit_rate",
"universal_image_quality_index",
"universal_spectral_angle_mapper",
"spectral_angle_mapper",
"word_error_rate",
"char_error_rate",
"match_error_rate",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.functional.image.gradients import image_gradients # noqa: F401
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio # noqa: F401
from torchmetrics.functional.image.sam import spectral_angle_mapper # noqa: F401
from torchmetrics.functional.image.ssim import ( # noqa: F401
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
Expand Down
25 changes: 16 additions & 9 deletions torchmetrics/functional/image/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def _sam_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"Expected `preds` and `target` to have BxCxHxW shape."
f" Got preds: {preds.shape} and target: {target.shape}."
)
if (preds.shape[1] <= 1) or (target.shape[1] <= 1):
raise ValueError(
"Expected dimension `preds` and `target` is larger than 1."
vumichien marked this conversation as resolved.
Show resolved Hide resolved
f" Got preds: {preds.shape[1]} and target: {target.shape[1]}."
)
return preds, target


Expand All @@ -62,20 +67,22 @@ def _sam_compute(
- ``'none'``: no reduction will be applied
Borda marked this conversation as resolved.
Show resolved Hide resolved

Example:
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(123))
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> preds, target = _sam_update(preds, target)
>>> _sam_compute(preds, target)
tensor(0.7377)
tensor(0.5943)
"""
B, C, H, W = preds.shape
preds = preds.reshape(B, C, H * W)
preds = torch.movedim(preds, 1, -1)
target = target.reshape(B, C, H * W)
target = torch.movedim(target, 1, -1)
vumichien marked this conversation as resolved.
Show resolved Hide resolved
sam_score = torch.clip(cosine_similarity(preds, target, reduction="none"), -1, 1).arccos()
return reduce(sam_score, reduction)


def universal_spectral_angle_mapper(
def spectral_angle_mapper(
preds: Tensor,
target: Tensor,
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
Borda marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -101,11 +108,11 @@ def universal_spectral_angle_mapper(
If ``preds`` and ``target`` don't have ``BxCxHxW shape``.

Example:
>>> from torchmetrics.functional import universal_spectral_angle_mapper
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(123))
>>> universal_spectral_angle_mapper(preds, target)
tensor(0.7377)
>>> from torchmetrics.functional import spectral_angle_mapper
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> spectral_angle_mapper(preds, target)
tensor(0.5943)

References: Roberta H. Yuhas, Alexander F. H. Goetz and Joe W. Boardman, "Discrimination among semi-arid
landscape endmembers using the Spectral Angle Mapper (SAM) algorithm" in PL, Summaries of the Third Annual JPL
Expand Down
6 changes: 3 additions & 3 deletions torchmetrics/image/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ class SpectralAngleMapper(Metric):
Example:
>>> import torch
>>> from torchmetrics import SpectralAngleMapper
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(123))
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> sam = SpectralAngleMapper()
>>> sam(preds, target)
tensor(0.7377)
tensor(0.5943)
"""

preds: List[Tensor]
Expand Down