Skip to content

Commit

Permalink
Refactor: allow reduction None (Lightning-AI#891)
Browse files Browse the repository at this point in the history
* allow reduction None
* Literal

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and gianscarpe committed Apr 21, 2022
1 parent 3f58b2e commit 7288108
Show file tree
Hide file tree
Showing 19 changed files with 82 additions and 49 deletions.
3 changes: 2 additions & 1 deletion torchmetrics/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.confusion_matrix import ConfusionMatrix
from torchmetrics.functional.classification.jaccard import _jaccard_from_confmat
Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__(
absent_score: float = 0.0,
threshold: float = 0.5,
multilabel: bool = False,
reduction: str = "elementwise_mean",
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
Expand Down
3 changes: 2 additions & 1 deletion torchmetrics/classification/kl_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.kl_divergence import _kld_compute, _kld_update
from torchmetrics.metric import Metric
Expand Down Expand Up @@ -79,7 +80,7 @@ class KLDivergence(Metric):
def __init__(
self,
log_prob: bool = False,
reduction: Optional[str] = "mean",
reduction: Literal["mean", "sum", "none", None] = "mean",
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
Expand Down
5 changes: 3 additions & 2 deletions torchmetrics/functional/classification/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.utilities.data import to_categorical
from torchmetrics.utilities.distributed import reduce
Expand Down Expand Up @@ -64,7 +65,7 @@ def dice_score(
bg: bool = False,
nan_score: float = 0.0,
no_fg_score: float = 0.0,
reduction: str = "elementwise_mean",
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Compute dice score from prediction scores.
Expand All @@ -78,7 +79,7 @@ def dice_score(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
- ``'none'`` or ``None``: no reduction will be applied
Return:
Tensor containing dice score
Expand Down
9 changes: 5 additions & 4 deletions torchmetrics/functional/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update
from torchmetrics.utilities.data import get_num_classes
Expand All @@ -26,7 +27,7 @@ def _jaccard_from_confmat(
num_classes: int,
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
reduction: str = "elementwise_mean",
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Computes the intersection over union from confusion matrix.
Expand All @@ -41,7 +42,7 @@ def _jaccard_from_confmat(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
- ``'none'`` or ``None``: no reduction will be applied
"""

# Remove the ignored class index from the scores.
Expand Down Expand Up @@ -73,7 +74,7 @@ def jaccard_index(
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
threshold: float = 0.5,
reduction: str = "elementwise_mean",
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
r"""
Computes `Jaccard index`_
Expand Down Expand Up @@ -113,7 +114,7 @@ def jaccard_index(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
- ``'none'`` or ``None``: no reduction will be applied
Return:
IoU score: Tensor containing single value if reduction is
Expand Down
9 changes: 6 additions & 3 deletions torchmetrics/functional/classification/kl_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple
from typing import Tuple

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.data import METRIC_EPS
Expand Down Expand Up @@ -47,7 +48,7 @@ def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> Tuple[Tensor, int]:
return measures, total


def _kld_compute(measures: Tensor, total: Tensor, reduction: Optional[str] = "mean") -> Tensor:
def _kld_compute(measures: Tensor, total: Tensor, reduction: Literal["mean", "sum", "none", None] = "mean") -> Tensor:
"""Computes the KL divergenece based on the type of reduction.
Args:
Expand Down Expand Up @@ -77,7 +78,9 @@ def _kld_compute(measures: Tensor, total: Tensor, reduction: Optional[str] = "me
return measures / total


def kl_divergence(p: Tensor, q: Tensor, log_prob: bool = False, reduction: Optional[str] = "mean") -> Tensor:
def kl_divergence(
p: Tensor, q: Tensor, log_prob: bool = False, reduction: Literal["mean", "sum", "none", None] = "mean"
) -> Tensor:
r"""Computes `KL divergence`_
.. math::
Expand Down
9 changes: 5 additions & 4 deletions torchmetrics/functional/image/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import torch
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.utilities import rank_zero_warn, reduce

Expand All @@ -24,7 +25,7 @@ def _psnr_compute(
n_obs: Tensor,
data_range: Tensor,
base: float = 10.0,
reduction: str = "elementwise_mean",
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Computes peak signal-to-noise ratio.
Expand All @@ -39,7 +40,7 @@ def _psnr_compute(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
- ``'none'`` or ``None``: no reduction will be applied
Example:
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
Expand Down Expand Up @@ -96,7 +97,7 @@ def peak_signal_noise_ratio(
target: Tensor,
data_range: Optional[float] = None,
base: float = 10.0,
reduction: str = "elementwise_mean",
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
dim: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Tensor:
"""Computes the peak signal-to-noise ratio.
Expand All @@ -112,7 +113,7 @@ def peak_signal_noise_ratio(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
- ``'none'`` or None``: no reduction will be applied
dim:
Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is
Expand Down
18 changes: 9 additions & 9 deletions torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _ssim_compute(
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
Expand All @@ -68,7 +68,7 @@ def _ssim_compute(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
- ``'none'`` or ``None``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of SSIM.
Expand Down Expand Up @@ -140,7 +140,7 @@ def structural_similarity_index_measure(
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
Expand All @@ -156,7 +156,7 @@ def structural_similarity_index_measure(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
- ``'none'`` or ``None``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of SSIM.
Expand Down Expand Up @@ -193,7 +193,7 @@ def _get_normalized_sim_and_cs(
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
Expand All @@ -213,7 +213,7 @@ def _multiscale_ssim_compute(
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
Expand All @@ -238,7 +238,7 @@ def _multiscale_ssim_compute(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
- ``'none'`` or ``None``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of structural similarity index measure.
Expand Down Expand Up @@ -304,7 +304,7 @@ def multiscale_structural_similarity_index_measure(
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
Expand All @@ -323,7 +323,7 @@ def multiscale_structural_similarity_index_measure(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
- ``'none'`` or ``None``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of structural similarity index measure.
Expand Down
8 changes: 4 additions & 4 deletions torchmetrics/functional/image/uqi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _uqi_compute(
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
data_range: Optional[float] = None,
return_contrast_sensitivity: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
Expand All @@ -66,7 +66,7 @@ def _uqi_compute(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
- ``'none'`` or ``None``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
Expand Down Expand Up @@ -128,7 +128,7 @@ def universal_image_quality_index(
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
data_range: Optional[float] = None,
) -> Tensor:
"""Universal Image Quality Index.
Expand All @@ -142,7 +142,7 @@ def universal_image_quality_index(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
- ``'none'`` or ``None``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
Expand Down
6 changes: 5 additions & 1 deletion torchmetrics/functional/pairwise/cosine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix

Expand Down Expand Up @@ -43,7 +44,10 @@ def _pairwise_cosine_similarity_update(


def pairwise_cosine_similarity(
x: Tensor, y: Optional[Tensor] = None, reduction: Optional[str] = None, zero_diagonal: Optional[bool] = None
x: Tensor,
y: Optional[Tensor] = None,
reduction: Literal["mean", "sum", "none", None] = None,
zero_diagonal: Optional[bool] = None,
) -> Tensor:
r"""
Calculates pairwise cosine similarity:
Expand Down
6 changes: 5 additions & 1 deletion torchmetrics/functional/pairwise/euclidean.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Optional

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix

Expand All @@ -38,7 +39,10 @@ def _pairwise_euclidean_distance_update(


def pairwise_euclidean_distance(
x: Tensor, y: Optional[Tensor] = None, reduction: Optional[str] = None, zero_diagonal: Optional[bool] = None
x: Tensor,
y: Optional[Tensor] = None,
reduction: Literal["mean", "sum", "none", None] = None,
zero_diagonal: Optional[bool] = None,
) -> Tensor:
r"""
Calculates pairwise euclidean distances:
Expand Down
6 changes: 5 additions & 1 deletion torchmetrics/functional/pairwise/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Optional

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix

Expand All @@ -37,7 +38,10 @@ def _pairwise_linear_similarity_update(


def pairwise_linear_similarity(
x: Tensor, y: Optional[Tensor] = None, reduction: Optional[str] = None, zero_diagonal: Optional[bool] = None
x: Tensor,
y: Optional[Tensor] = None,
reduction: Literal["mean", "sum", "none", None] = None,
zero_diagonal: Optional[bool] = None,
) -> Tensor:
r"""
Calculates pairwise linear similarity:
Expand Down
6 changes: 5 additions & 1 deletion torchmetrics/functional/pairwise/manhattan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Optional

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix

Expand All @@ -37,7 +38,10 @@ def _pairwise_manhattan_distance_update(


def pairwise_manhattan_distance(
x: Tensor, y: Optional[Tensor] = None, reduction: Optional[str] = None, zero_diagonal: Optional[bool] = None
x: Tensor,
y: Optional[Tensor] = None,
reduction: Literal["mean", "sum", "none", None] = None,
zero_diagonal: Optional[bool] = None,
) -> Tensor:
r"""
Calculates pairwise manhattan distance:
Expand Down
Loading

0 comments on commit 7288108

Please sign in to comment.