Skip to content

Commit

Permalink
Merge branch 'master' into fairness
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Mar 4, 2023
2 parents 5be0127 + 7821012 commit d0f7b17
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1481](https://github.com/Lightning-AI/metrics/pull/1481),
[#1480](https://github.com/Lightning-AI/metrics/pull/1480),
[#1490](https://github.com/Lightning-AI/metrics/pull/1490),
[#1581](https://github.com/Lightning-AI/metrics/pull/1581),
)


Expand Down
1 change: 1 addition & 0 deletions docs/source/nominal/cramers_v.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ________________

.. autoclass:: torchmetrics.CramersV
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
Expand Down
1 change: 1 addition & 0 deletions docs/source/nominal/pearsons_contingency_coefficient.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ________________

.. autoclass:: torchmetrics.PearsonsContingencyCoefficient
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
Expand Down
1 change: 1 addition & 0 deletions docs/source/nominal/theils_u.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ________________

.. autoclass:: torchmetrics.TheilsU
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
Expand Down
1 change: 1 addition & 0 deletions docs/source/nominal/tschuprows_t.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ________________

.. autoclass:: torchmetrics.TschuprowsT
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
Expand Down
48 changes: 47 additions & 1 deletion src/torchmetrics/nominal/cramers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor
Expand All @@ -20,6 +20,11 @@
from torchmetrics.functional.nominal.cramers import _cramers_v_compute, _cramers_v_update
from torchmetrics.functional.nominal.utils import _nominal_input_validation
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["CramersV.plot"]


class CramersV(Metric):
Expand Down Expand Up @@ -69,6 +74,8 @@ class CramersV(Metric):
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
plot_lower_bound = 0.0
plot_upper_bound = 1.0
confmat: Tensor

def __init__(
Expand Down Expand Up @@ -109,3 +116,42 @@ def update(self, preds: Tensor, target: Tensor) -> None:
def compute(self) -> Tensor:
"""Compute Cramer's V statistic."""
return _cramers_v_compute(self.confmat, self.bias_correction)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import CramersV
>>> metric = CramersV(num_classes=5)
>>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import CramersV
>>> metric = CramersV(num_classes=5)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/nominal/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor
Expand All @@ -23,6 +23,11 @@
)
from torchmetrics.functional.nominal.utils import _nominal_input_validation
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["PearsonsContingencyCoefficient.plot"]


class PearsonsContingencyCoefficient(Metric):
Expand Down Expand Up @@ -74,6 +79,8 @@ class PearsonsContingencyCoefficient(Metric):
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
plot_lower_bound = 0.0
plot_upper_bound = 1.0
confmat: Tensor

def __init__(
Expand Down Expand Up @@ -114,3 +121,42 @@ def update(self, preds: Tensor, target: Tensor) -> None:
def compute(self) -> Tensor:
"""Compute Pearson's Contingency Coefficient statistic."""
return _pearsons_contingency_coefficient_compute(self.confmat)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import PearsonsContingencyCoefficient
>>> metric = PearsonsContingencyCoefficient(num_classes=5)
>>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import PearsonsContingencyCoefficient
>>> metric = PearsonsContingencyCoefficient(num_classes=5)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
51 changes: 49 additions & 2 deletions src/torchmetrics/nominal/theils_u.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor
Expand All @@ -20,6 +20,11 @@
from torchmetrics.functional.nominal.theils_u import _theils_u_compute, _theils_u_update
from torchmetrics.functional.nominal.utils import _nominal_input_validation
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["TheilsU.plot"]


class TheilsU(Metric):
Expand Down Expand Up @@ -50,13 +55,16 @@ class TheilsU(Metric):
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(10, (10,))
>>> target = torch.randint(10, (10,))
>>> TheilsU(num_classes=10)(preds, target)
>>> metric = TheilsU(num_classes=10)
>>> metric(preds, target)
tensor(0.8530)
"""

full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
plot_lower_bound = 0.0
plot_upper_bound = 1.0
confmat: Tensor

def __init__(
Expand Down Expand Up @@ -92,3 +100,42 @@ def update(self, preds: Tensor, target: Tensor) -> None:
def compute(self) -> Tensor:
"""Compute Theil's U statistic."""
return _theils_u_compute(self.confmat)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import TheilsU
>>> metric = TheilsU(num_classes=10)
>>> metric.update(torch.randint(10, (10,)), torch.randint(10, (10,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import TheilsU
>>> metric = TheilsU(num_classes=10)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.randint(10, (10,)), torch.randint(10, (10,))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/nominal/tschuprows.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor
Expand All @@ -20,6 +20,11 @@
from torchmetrics.functional.nominal.tschuprows import _tschuprows_t_compute, _tschuprows_t_update
from torchmetrics.functional.nominal.utils import _nominal_input_validation
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["TschuprowsT.plot"]


class TschuprowsT(Metric):
Expand Down Expand Up @@ -69,6 +74,8 @@ class TschuprowsT(Metric):
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
plot_lower_bound = 0.0
plot_upper_bound = 1.0
confmat: Tensor

def __init__(
Expand Down Expand Up @@ -109,3 +116,42 @@ def update(self, preds: Tensor, target: Tensor) -> None:
def compute(self) -> Tensor:
"""Compute Tschuprow's T statistic."""
return _tschuprows_t_compute(self.confmat, self.bias_correction)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import TschuprowsT
>>> metric = TschuprowsT(num_classes=5)
>>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import TschuprowsT
>>> metric = TschuprowsT(num_classes=5)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
12 changes: 12 additions & 0 deletions tests/unittests/utilities/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
StructuralSimilarityIndexMeasure,
UniversalImageQualityIndex,
)
from torchmetrics.nominal import CramersV, PearsonsContingencyCoefficient, TheilsU, TschuprowsT
from torchmetrics.regression import MeanSquaredError

_rand_input = lambda: torch.rand(10)
Expand All @@ -59,6 +60,7 @@
_multilabel_randint_input = lambda: torch.randint(2, (10, 3))
_audio_input = lambda: torch.randn(8000)
_image_input = lambda: torch.rand([8, 3, 16, 16])
_nominal_input = lambda: torch.randint(0, 4, (100,))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -100,7 +102,17 @@
BinaryROC,
_rand_input,
_binary_randint_input,
id="binary roc",
),
pytest.param(
partial(PearsonsContingencyCoefficient, num_classes=5),
_nominal_input,
_nominal_input,
id="pearson contigency coef",
),
pytest.param(partial(TheilsU, num_classes=5), _nominal_input, _nominal_input, id="theils U"),
pytest.param(partial(TschuprowsT, num_classes=5), _nominal_input, _nominal_input, id="tschuprows T"),
pytest.param(partial(CramersV, num_classes=5), _nominal_input, _nominal_input, id="cramers V"),
pytest.param(
SpectralDistortionIndex,
_image_input,
Expand Down

0 comments on commit d0f7b17

Please sign in to comment.