Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RelativeAverageSpectralError and RootMeanSquaredErrorUsingSlidingWindow #816

Merged
merged 81 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
9d807d6
New Metric RASE for images
Piyush-97 Jan 30, 2022
d072f32
Merge branch 'master' into enhan/rase
stancld Mar 4, 2022
29aa30a
[WIP] Add RASE + supporting RMSE_SW
stancld Mar 5, 2022
034e2f0
Merge branch 'master' into enhan/rase
SkafteNicki Mar 21, 2022
77d0817
Apply suggestions from code review
SkafteNicki Mar 21, 2022
fd943fc
Merge branch 'master' into enhan/rase
SkafteNicki Mar 21, 2022
a791e60
move requirement
SkafteNicki Mar 21, 2022
52c8c2d
add docstrings
SkafteNicki Mar 21, 2022
a712a8b
differentiable
SkafteNicki Mar 21, 2022
f492492
device placement
SkafteNicki Mar 21, 2022
12fb56a
Merge branch 'master' into enhan/rase
Borda Mar 24, 2022
16be948
Merge branch 'master' into enhan/rase
Borda Mar 25, 2022
1712dc1
updates
Borda Mar 25, 2022
8e4030d
Merge branch 'master' into enhan/rase
stancld Apr 9, 2022
5fe3795
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2022
69dcc82
Merge branch 'master' into enhan/rase
SkafteNicki Apr 11, 2022
60bb3e4
Merge branch 'master' into enhan/rase
Borda Apr 13, 2022
0a7ec0b
Merge branch 'master' into enhan/rase
SkafteNicki Apr 26, 2022
e2206a9
docs
SkafteNicki Apr 26, 2022
9a0f529
remove arg
SkafteNicki Apr 27, 2022
385b731
Merge branch 'master' into enhan/rase
stancld Apr 27, 2022
05b313c
try getting test passing
SkafteNicki Apr 27, 2022
bb04bdf
Merge branch 'master' into enhan/rase
Borda May 5, 2022
3b70886
Merge branch 'master' into enhan/rase
Borda May 5, 2022
7998d37
Merge branch 'master' into enhan/rase
Borda May 23, 2022
cf599b7
Merge branch 'master' into enhan/rase
Borda May 25, 2022
55eea02
Merge branch 'master' into enhan/rase
SkafteNicki Jun 9, 2022
48e3ddb
Merge branch 'master' into enhan/rase
stancld Oct 12, 2022
6be8c6d
Merge branch 'master' into enhan/rase
SkafteNicki Oct 24, 2022
814c321
Merge branch 'master' into enhan/rase
stancld Nov 6, 2022
fba37dc
Merge branch 'master' into enhan/rase
SkafteNicki Nov 14, 2022
e0d3bbd
Merge branch 'master' into enhan/rase
stancld Nov 15, 2022
43fb015
Update chlog
stancld Nov 15, 2022
9ef8c8d
Fix tests imports
stancld Nov 15, 2022
9d2467a
Merge branch 'master' into enhan/rase
SkafteNicki Nov 18, 2022
3da464b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
281d1b4
Merge branch 'master' into enhan/rase
Borda Nov 18, 2022
e8fcc8c
Fix padding (uniform filter still not working)
stancld Nov 19, 2022
0843cad
Merge branch 'master' into enhan/rase
stancld Nov 19, 2022
6885b5b
Merge branch 'master' into enhan/rase
Borda Nov 22, 2022
da22ad8
Merge branch 'master' into enhan/rase
SkafteNicki Nov 22, 2022
f1b5666
Merge branch 'master' into enhan/rase
Borda Nov 30, 2022
b886378
Merge branch 'master' into enhan/rase
stancld Dec 4, 2022
98aa555
Merge branch 'master' into enhan/rase
Borda Dec 14, 2022
ea736b0
Merge branch 'master' into enhan/rase
Borda Dec 23, 2022
603d459
Merge branch 'master' into enhan/rase
Borda Jan 30, 2023
159698d
Merge branch 'master' into enhan/rase
Borda Feb 7, 2023
a13bd49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 7, 2023
5721330
Merge branch 'master' into enhan/rase
Borda Feb 18, 2023
03ae647
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2023
665ffa0
Merge branch 'master' into enhan/rase
mergify[bot] Feb 20, 2023
a0ef3fa
Merge branch 'master' into enhan/rase
Borda Feb 20, 2023
d823ef7
Merge branch 'master' into enhan/rase
mergify[bot] Feb 20, 2023
b128d7b
Merge branch 'master' into enhan/rase
mergify[bot] Feb 20, 2023
b7375bd
Merge branch 'master' into enhan/rase
Borda Feb 20, 2023
efb7981
Merge branch 'master' into enhan/rase
mergify[bot] Feb 20, 2023
c666dff
Merge branch 'master' into enhan/rase
mergify[bot] Feb 21, 2023
cb2b875
Merge branch 'master' into enhan/rase
Borda Feb 22, 2023
53566a9
ruff: first line split + imperative mood (#1548)
SkafteNicki Feb 24, 2023
df7fdae
Merge branch 'master' into enhan/rase
Borda Feb 24, 2023
8c3a123
Merge branch 'master' into enhan/rase
mergify[bot] Feb 24, 2023
2cfa4b4
Merge branch 'master' into enhan/rase
mergify[bot] Feb 24, 2023
fb79335
Merge branch 'master' into enhan/rase
mergify[bot] Feb 24, 2023
5c15826
Merge branch 'master' into enhan/rase
mergify[bot] Feb 24, 2023
def030e
Merge branch 'master' into enhan/rase
mergify[bot] Feb 25, 2023
bd1f2ae
Merge branch 'master' into enhan/rase
mergify[bot] Feb 25, 2023
a5c014d
Merge branch 'master' into enhan/rase
mergify[bot] Feb 25, 2023
5ee34e1
Merge branch 'master' into enhan/rase
Borda Feb 27, 2023
4dd1d2f
docs
Borda Feb 27, 2023
1446a89
compute_on_step
Borda Feb 27, 2023
5920718
Merge branch 'master' into enhan/rase
mergify[bot] Feb 27, 2023
0014a28
Merge branch 'master' into enhan/rase
mergify[bot] Feb 27, 2023
bf0ff55
Merge branch 'master' into enhan/rase
mergify[bot] Feb 28, 2023
ab520ed
Merge branch 'master' into enhan/rase
mergify[bot] Feb 28, 2023
eb8cc87
Merge branch 'master' into enhan/rase
mergify[bot] Feb 28, 2023
de09c59
try fixing tests
SkafteNicki Feb 28, 2023
a94465e
fix typing and doc issues
SkafteNicki Feb 28, 2023
b99c26e
changelog
SkafteNicki Feb 28, 2023
c395371
fixes
SkafteNicki Feb 28, 2023
ea0a21f
another link
SkafteNicki Feb 28, 2023
144fcac
try fix
SkafteNicki Feb 28, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added`RelativeAverageSpectralError` and `RootMeanSquaredErrorUsingSlidingWindow` to image package ([#816](https://github.com/PyTorchLightning/metrics/pull/816))


- Added support for `SpecificityAtSensitivity` Metric ([#1432](https://github.com/Lightning-AI/metrics/pull/1432))


Expand Down
23 changes: 23 additions & 0 deletions docs/source/image/relative_average_spectral_error.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Relative Average Spectral Error (RASE)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

######################################
Relative Average Spectral Error (RASE)
######################################

Module Interface
________________

.. autoclass:: torchmetrics.RelativeAverageSpectralError
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.relative_average_spectral_error
:noindex:
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Root Mean Squared Error Using Sliding Window
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

############################################
Root Mean Squared Error Using Sliding Window
############################################

Module Interface
________________

.. autoclass:: torchmetrics.RootMeanSquaredErrorUsingSlidingWindow
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.root_mean_squared_error_using_sliding_window
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
.. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
.. _UniversalImageQualityIndex: https://ieeexplore.ieee.org/document/995823
.. _SpectralDistortionIndex: https://www.semanticscholar.org/paper/Multispectral-and-panchromatic-data-fusion-without-Alparone-Aiazzi/b6db12e3785326577cb95fd743fecbf5bc66c7c9
.. _RelativeAverageSpectralError: https://www.semanticscholar.org/paper/Data-Fusion.-Definitions-and-Architectures-Fusion-Wald/51b2b81e5124b3bb7ec53517a5dd64d8e348cadf
.. _WMAPE: https://en.wikipedia.org/wiki/WMAPE
.. _CER: https://rechtsprechung-im-ostseeraum.archiv.uni-greifswald.de/word-error-rate-character-error-rate-how-to-evaluate-a-model
.. _MER: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf
Expand Down
1 change: 1 addition & 0 deletions requirements/image_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
scikit-image >0.17.1, <=0.19.3
kornia >=0.6.7, <0.6.11
pytorch-msssim ==0.2.1
sewar >=0.4.4, <=0.4.5
4 changes: 4 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
ErrorRelativeGlobalDimensionlessSynthesis,
MultiScaleStructuralSimilarityIndexMeasure,
PeakSignalNoiseRatio,
RelativeAverageSpectralError,
RootMeanSquaredErrorUsingSlidingWindow,
SpectralAngleMapper,
SpectralDistortionIndex,
StructuralSimilarityIndexMeasure,
Expand Down Expand Up @@ -161,6 +163,7 @@
"PeakSignalNoiseRatio",
"R2Score",
"Recall",
"RelativeAverageSpectralError",
"RetrievalFallOut",
"RetrievalHitRate",
"RetrievalMAP",
Expand All @@ -172,6 +175,7 @@
"RetrievalPrecisionRecallCurve",
"RetrievalRecallAtFixedPrecision",
"ROC",
"RootMeanSquaredErrorUsingSlidingWindow",
"SacreBLEUScore",
"SignalDistortionRatio",
"ScaleInvariantSignalDistortionRatio",
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.rase import relative_average_spectral_error
from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.functional.image.ssim import (
multiscale_structural_similarity_index_measure,
Expand Down Expand Up @@ -155,6 +157,7 @@
"peak_signal_noise_ratio",
"r2_score",
"recall",
"relative_average_spectral_error",
"retrieval_average_precision",
"retrieval_fall_out",
"retrieval_hit_rate",
Expand All @@ -165,6 +168,7 @@
"retrieval_reciprocal_rank",
"retrieval_precision_recall_curve",
"roc",
"root_mean_squared_error_using_sliding_window",
"rouge_score",
"sacre_bleu_score",
"signal_distortion_ratio",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis # noqa: F401
from torchmetrics.functional.image.gradients import image_gradients # noqa: F401
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio # noqa: F401
from torchmetrics.functional.image.rase import relative_average_spectral_error # noqa: F401
from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window # noqa: F401
from torchmetrics.functional.image.sam import spectral_angle_mapper # noqa: F401
from torchmetrics.functional.image.ssim import ( # noqa: F401
multiscale_structural_similarity_index_measure,
Expand Down
97 changes: 77 additions & 20 deletions src/torchmetrics/functional/image/helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence, Union
from typing import Sequence, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -57,6 +57,79 @@ def _gaussian_kernel_2d(
return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])


def _uniform_weight_bias_conv2d(inputs: Tensor, window_size: int) -> Tuple[Tensor, Tensor]:
"""Construct uniform weight and bias for a 2d convolution.

Args:
inputs: Input image
window_size: size of convolutional kernel

Return:
The weight and bias for 2d convolution
"""
kernel_weight = torch.ones(1, 1, window_size, window_size, dtype=inputs.dtype, device=inputs.device)
kernel_weight /= window_size**2
kernel_bias = torch.zeros(1, dtype=inputs.dtype, device=inputs.device)
return kernel_weight, kernel_bias


def _single_dimension_pad(inputs: Tensor, dim: int, pad: int, outer_pad: int = 0) -> Tensor:
"""Apply single-dimension reflection padding to match scipy implementation.

Args:
inputs: Input image
dim: A dimension the image should be padded over
pad: Number of pads
outer_pad: Number of outer pads

Return:
Image padded over a single dimension
"""
_max = inputs.shape[dim]
x = torch.index_select(inputs, dim, torch.arange(pad - 1, -1, -1).to(inputs.device))
y = torch.index_select(inputs, dim, torch.arange(_max - 1, _max - pad - outer_pad, -1).to(inputs.device))
return torch.cat((x, inputs, y), dim)


def _reflection_pad_2d(inputs: Tensor, pad: int, outer_pad: int = 0) -> Tensor:
"""Apply reflection padding to the input image.

Args:
inputs: Input image
pad: Number of pads
outer_pad: Number of outer pads

Return:
Padded image
"""
for dim in [2, 3]:
inputs = _single_dimension_pad(inputs, dim, pad, outer_pad)
return inputs


def _uniform_filter(inputs: Tensor, window_size: int) -> Tensor:
"""Applies uniform filtew with a window of a given size over the input image.

Args:
inputs: Input image
window_size: Sliding window used for rmse calculation

Return:
Image transformed with the uniform input
"""
inputs = _reflection_pad_2d(inputs, window_size // 2, window_size % 2)
kernel_weight, kernel_bias = _uniform_weight_bias_conv2d(inputs, window_size)
# Iterate over channels
inputs = torch.cat(
[
F.conv2d(inputs[:, channel].unsqueeze(1), kernel_weight, kernel_bias, padding=0)
for channel in range(inputs.shape[1])
],
dim=1,
)
return inputs


def _gaussian_kernel_3d(
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
Expand All @@ -80,23 +153,6 @@ def _gaussian_kernel_3d(
return kernel.expand(channel, 1, kernel_size[0], kernel_size[1], kernel_size[2])


def _single_dimension_pad(inputs: Tensor, dim: int, pad: int) -> Tensor:
"""Reflective padding of input along a specific dimension.

Args:
inputs: tensor to pad
dim: dimension to pad along
pad: amount of padding to add

Returns:
padded input
"""
_max = inputs.shape[dim] - 2
x = torch.index_select(inputs, dim, torch.arange(pad, 0, -1, device=inputs.device))
y = torch.index_select(inputs, dim, torch.arange(_max, _max - pad, -1, device=inputs.device))
return torch.cat((x, inputs, y), dim)


def _reflection_pad_3d(inputs: Tensor, pad_h: int, pad_w: int, pad_d: int) -> Tensor:
"""Reflective padding of 3d input.

Expand All @@ -113,8 +169,9 @@ def _reflection_pad_3d(inputs: Tensor, pad_h: int, pad_w: int, pad_d: int) -> Te
inputs = F.pad(inputs, (pad_h, pad_h, pad_w, pad_w, pad_d, pad_d), mode="reflect")
else:
rank_zero_warn(
"An older version of pyTorch is used. For optimal speed, please upgrade to at least pyTorch 1.10."
"An older version of pyTorch is used."
" For optimal speed, please upgrade to at least PyTorch v1.10 or higher."
)
for dim, pad in enumerate([pad_h, pad_w, pad_d]):
inputs = _single_dimension_pad(inputs, dim + 2, pad)
inputs = _single_dimension_pad(inputs, dim + 2, pad, outer_pad=1)
return inputs
101 changes: 101 additions & 0 deletions src/torchmetrics/functional/image/rase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.functional.image.helper import _uniform_filter
from torchmetrics.functional.image.rmse_sw import _rmse_sw_compute, _rmse_sw_update


def _rase_update(
preds: Tensor, target: Tensor, window_size: int, rmse_map: Tensor, target_sum: Tensor, total_images: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
"""Calculates the sum of RMSE map values for the batch of examples and update intermediate states.

Args:
preds: Deformed image
target: Ground truth image
window_size: Sliding window used for RMSE calculation
rmse_map: Sum of RMSE map values over all examples
target_sum: target...
total_images: Total number of images

Return:
Intermediate state of RMSE map
Updated total number of already processed images
"""
_, rmse_map, total_images = _rmse_sw_update(
preds, target, window_size, rmse_val_sum=None, rmse_map=rmse_map, total_images=total_images
)
target_sum += torch.sum(_uniform_filter(target, window_size) / (window_size**2), dim=0)
return rmse_map, target_sum, total_images


def _rase_compute(rmse_map: Tensor, target_sum: Tensor, total_images: Tensor, window_size: int) -> Tensor:
"""Compute RASE.

Args:
rmse_map: Sum of RMSE map values over all examples
target_sum: target...
total_images: Total number of images.
window_size: Sliding window used for rmse calculation

Return:
Relative Average Spectral Error (RASE)
"""
_, rmse_map = _rmse_sw_compute(rmse_val_sum=None, rmse_map=rmse_map, total_images=total_images)
target_mean = target_sum / total_images
target_mean = target_mean.mean(0) # mean over image channels
rase_map = 100 / target_mean * torch.sqrt(torch.mean(rmse_map**2, 0))
crop_slide = round(window_size / 2)

return torch.mean(rase_map[crop_slide:-crop_slide, crop_slide:-crop_slide])


def relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: int = 8) -> Tensor:
"""Computes Relative Average Spectral Error (RASE) (RelativeAverageSpectralError_).

Args:
preds: Deformed image
target: Ground truth image
window_size: Sliding window used for rmse calculation

Return:
Relative Average Spectral Error (RASE)

Example:
>>> from torchmetrics.functional import relative_average_spectral_error
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> relative_average_spectral_error(preds, target)
tensor(5114.6641)

Raises:
ValueError: If ``window_size`` is not a positive integer.
"""
if not isinstance(window_size, int) or isinstance(window_size, int) and window_size < 1:
raise ValueError("Argument `window_size` is expected to be a positive integer.")

img_shape = target.shape[1:] # [num_channels, width, height]
rmse_map = torch.zeros(img_shape, dtype=target.dtype, device=target.device)
target_sum = torch.zeros(img_shape, dtype=target.dtype, device=target.device)
total_images = torch.tensor(0.0, device=target.device)

rmse_map, target_sum, total_images = _rase_update(preds, target, window_size, rmse_map, target_sum, total_images)
rase = _rase_compute(rmse_map, target_sum, total_images, window_size)
return rase
Loading