Skip to content

Commit

Permalink
Add plotting 19/n (#1662)
Browse files Browse the repository at this point in the history
* fixes
* typing

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
  • Loading branch information
3 people authored Mar 31, 2023
1 parent b5ac44e commit c553325
Show file tree
Hide file tree
Showing 83 changed files with 285 additions and 75 deletions.
4 changes: 2 additions & 2 deletions src/torchmetrics/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ class PerceptualEvaluationSpeechQuality(Metric):
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
plot_lower_bound = 1.0
plot_upper_bound = 4.5
plot_lower_bound: float = 1.0
plot_upper_bound: float = 4.5

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/audio/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class PermutationInvariantTraining(Metric):
is_differentiable: bool = True
sum_pit_metric: Tensor
total: Tensor
plot_lower_bound = -10.0
plot_upper_bound = 1.0
plot_lower_bound: float = -10.0
plot_upper_bound: float = 1.0

def __init__(
self,
Expand Down
8 changes: 4 additions & 4 deletions src/torchmetrics/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ class SignalDistortionRatio(Metric):
full_state_update: bool = False
is_differentiable: bool = True
higher_is_better: bool = True
plot_lower_bound = -20.0
plot_upper_bound = 1.0
plot_lower_bound: float = -20.0
plot_upper_bound: float = 1.0

def __init__(
self,
Expand Down Expand Up @@ -195,8 +195,8 @@ class ScaleInvariantSignalDistortionRatio(Metric):
higher_is_better = True
sum_si_sdr: Tensor
total: Tensor
plot_lower_bound = -40.0
plot_upper_bound = 20.0
plot_lower_bound: float = -40.0
plot_upper_bound: float = 20.0

def __init__(
self,
Expand Down
8 changes: 4 additions & 4 deletions src/torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class SignalNoiseRatio(Metric):
higher_is_better: bool = True
sum_snr: Tensor
total: Tensor
plot_lower_bound = -20.0
plot_upper_bound = 5.0
plot_lower_bound: float = -20.0
plot_upper_bound: float = 5.0

def __init__(
self,
Expand Down Expand Up @@ -164,8 +164,8 @@ class ScaleInvariantSignalNoiseRatio(Metric):
sum_si_snr: Tensor
total: Tensor
higher_is_better = True
plot_lower_bound = -20.0
plot_upper_bound = 10.0
plot_lower_bound: float = -20.0
plot_upper_bound: float = 10.0

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class ShortTimeObjectiveIntelligibility(Metric):
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
plot_lower_bound = -20.0
plot_upper_bound = 5.0
plot_lower_bound: float = -20.0
plot_upper_bound: float = 5.0

def __init__(
self,
Expand Down
16 changes: 8 additions & 8 deletions src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ class BinaryAccuracy(BinaryStatScores):
is_differentiable = False
higher_is_better = True
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def compute(self) -> Tensor:
"""Compute accuracy based on inputs passed in to ``update`` previously."""
Expand Down Expand Up @@ -237,9 +237,9 @@ class MulticlassAccuracy(MulticlassStatScores):
is_differentiable = False
higher_is_better = True
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_legend_name = "Class"
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Class"

def compute(self) -> Tensor:
"""Compute accuracy based on inputs passed in to ``update`` previously."""
Expand Down Expand Up @@ -384,9 +384,9 @@ class MultilabelAccuracy(MultilabelStatScores):
is_differentiable = False
higher_is_better = True
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_legend_name = "Label"
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Label"

def compute(self) -> Tensor:
"""Compute accuracy based on inputs passed in to ``update`` previously."""
Expand Down
16 changes: 8 additions & 8 deletions src/torchmetrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ class BinaryAUROC(BinaryPrecisionRecallCurve):
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(
self,
Expand Down Expand Up @@ -239,9 +239,9 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve):
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_legend_name = "Class"
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Class"

def __init__(
self,
Expand Down Expand Up @@ -386,9 +386,9 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve):
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_legend_name = "Label"
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Label"

def __init__(
self,
Expand Down
16 changes: 8 additions & 8 deletions src/torchmetrics/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def compute(self) -> Tensor:
"""Compute metric."""
Expand Down Expand Up @@ -238,9 +238,9 @@ class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve):
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_legend_name = "Class"
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Class"

def __init__(
self,
Expand Down Expand Up @@ -390,9 +390,9 @@ class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve):
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_legend_name = "Label"
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Label"

def __init__(
self,
Expand Down
9 changes: 5 additions & 4 deletions src/torchmetrics/classification/calibration_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ class BinaryCalibrationError(Metric):
is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(
self,
Expand Down Expand Up @@ -245,8 +245,9 @@ class MulticlassCalibrationError(Metric):
is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Class"

def __init__(
self,
Expand Down
9 changes: 5 additions & 4 deletions src/torchmetrics/classification/cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ class labels.
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(
self,
Expand Down Expand Up @@ -215,8 +215,9 @@ class labels.
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound = 0.0
plot_upper_bound = 1.0
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Class"

def __init__(
self,
Expand Down
3 changes: 3 additions & 0 deletions src/torchmetrics/classification/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ class Dice(Metric):
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Class"

@no_type_check
def __init__(
Expand Down
6 changes: 6 additions & 0 deletions src/torchmetrics/classification/exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ class MulticlassExactMatch(Metric):
is_differentiable = False
higher_is_better = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Class"

def __init__(
self,
Expand Down Expand Up @@ -253,6 +256,9 @@ class MultilabelExactMatch(Metric):
is_differentiable = False
higher_is_better = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Label"

def __init__(
self,
Expand Down
16 changes: 16 additions & 0 deletions src/torchmetrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class BinaryFBetaScore(BinaryStatScores):
is_differentiable: bool = False
higher_is_better: Optional[bool] = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(
self,
Expand Down Expand Up @@ -280,6 +282,9 @@ class MulticlassFBetaScore(MulticlassStatScores):
is_differentiable: bool = False
higher_is_better: Optional[bool] = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Class"

def __init__(
self,
Expand Down Expand Up @@ -451,6 +456,9 @@ class MultilabelFBetaScore(MultilabelStatScores):
is_differentiable: bool = False
higher_is_better: Optional[bool] = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Label"

def __init__(
self,
Expand Down Expand Up @@ -592,6 +600,8 @@ class BinaryF1Score(BinaryFBetaScore):
is_differentiable: bool = False
higher_is_better: Optional[bool] = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(
self,
Expand Down Expand Up @@ -752,6 +762,9 @@ class MulticlassF1Score(MulticlassFBetaScore):
is_differentiable: bool = False
higher_is_better: Optional[bool] = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Class"

def __init__(
self,
Expand Down Expand Up @@ -912,6 +925,9 @@ class MultilabelF1Score(MultilabelFBetaScore):
is_differentiable: bool = False
higher_is_better: Optional[bool] = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Label"

def __init__(
self,
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/group_fairness.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class BinaryGroupStatRates(_AbstractGroupStatScores):
is_differentiable = False
higher_is_better = False
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(
self,
Expand Down Expand Up @@ -200,6 +202,8 @@ class BinaryFairness(_AbstractGroupStatScores):
is_differentiable = False
higher_is_better = False
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(
self,
Expand Down
8 changes: 8 additions & 0 deletions src/torchmetrics/classification/hamming.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class BinaryHammingDistance(BinaryStatScores):
is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def compute(self) -> Tensor:
"""Compute metric."""
Expand Down Expand Up @@ -248,6 +250,9 @@ class MulticlassHammingDistance(MulticlassStatScores):
is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Class"

def compute(self) -> Tensor:
"""Compute metric."""
Expand Down Expand Up @@ -393,6 +398,9 @@ class MultilabelHammingDistance(MultilabelStatScores):
is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Label"

def compute(self) -> Tensor:
"""Compute metric."""
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/classification/hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class BinaryHingeLoss(Metric):
is_differentiable: bool = True
higher_is_better: bool = False
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(
self,
Expand Down Expand Up @@ -219,6 +221,9 @@ class MulticlassHingeLoss(Metric):
is_differentiable: bool = True
higher_is_better: bool = False
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
plot_legend_name: str = "Class"

def __init__(
self,
Expand Down
Loading

0 comments on commit c553325

Please sign in to comment.