From 7821012b973aa3f4174d7910aaab9cad0d45fc0c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 4 Mar 2023 01:17:01 +0100 Subject: [PATCH] Add `plot` to nominal (#1581) --- CHANGELOG.md | 1 + docs/source/nominal/cramers_v.rst | 1 + .../pearsons_contingency_coefficient.rst | 1 + docs/source/nominal/theils_u.rst | 1 + docs/source/nominal/tschuprows_t.rst | 1 + src/torchmetrics/nominal/cramers.py | 48 ++++++++++++++++- src/torchmetrics/nominal/pearson.py | 48 ++++++++++++++++- src/torchmetrics/nominal/theils_u.py | 51 ++++++++++++++++++- src/torchmetrics/nominal/tschuprows.py | 48 ++++++++++++++++- tests/unittests/utilities/test_plot.py | 12 +++++ 10 files changed, 207 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20bd9fda7be..de301949719 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1481](https://github.com/Lightning-AI/metrics/pull/1481), [#1480](https://github.com/Lightning-AI/metrics/pull/1480), [#1490](https://github.com/Lightning-AI/metrics/pull/1490), + [#1581](https://github.com/Lightning-AI/metrics/pull/1581), ) diff --git a/docs/source/nominal/cramers_v.rst b/docs/source/nominal/cramers_v.rst index 7e6533a730a..2a2dc4dccad 100644 --- a/docs/source/nominal/cramers_v.rst +++ b/docs/source/nominal/cramers_v.rst @@ -12,6 +12,7 @@ ________________ .. autoclass:: torchmetrics.CramersV :noindex: + :exclude-members: update, compute Functional Interface ____________________ diff --git a/docs/source/nominal/pearsons_contingency_coefficient.rst b/docs/source/nominal/pearsons_contingency_coefficient.rst index 6715ade54d6..039052b8780 100644 --- a/docs/source/nominal/pearsons_contingency_coefficient.rst +++ b/docs/source/nominal/pearsons_contingency_coefficient.rst @@ -12,6 +12,7 @@ ________________ .. autoclass:: torchmetrics.PearsonsContingencyCoefficient :noindex: + :exclude-members: update, compute Functional Interface ____________________ diff --git a/docs/source/nominal/theils_u.rst b/docs/source/nominal/theils_u.rst index 8f9089559bc..33a9a6e887a 100644 --- a/docs/source/nominal/theils_u.rst +++ b/docs/source/nominal/theils_u.rst @@ -12,6 +12,7 @@ ________________ .. autoclass:: torchmetrics.TheilsU :noindex: + :exclude-members: update, compute Functional Interface ____________________ diff --git a/docs/source/nominal/tschuprows_t.rst b/docs/source/nominal/tschuprows_t.rst index dafbea3f3e7..71b640e4e67 100644 --- a/docs/source/nominal/tschuprows_t.rst +++ b/docs/source/nominal/tschuprows_t.rst @@ -12,6 +12,7 @@ ________________ .. autoclass:: torchmetrics.TschuprowsT :noindex: + :exclude-members: update, compute Functional Interface ____________________ diff --git a/src/torchmetrics/nominal/cramers.py b/src/torchmetrics/nominal/cramers.py index ea86140dd08..a52810587b6 100644 --- a/src/torchmetrics/nominal/cramers.py +++ b/src/torchmetrics/nominal/cramers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor @@ -20,6 +20,11 @@ from torchmetrics.functional.nominal.cramers import _cramers_v_compute, _cramers_v_update from torchmetrics.functional.nominal.utils import _nominal_input_validation from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["CramersV.plot"] class CramersV(Metric): @@ -69,6 +74,8 @@ class CramersV(Metric): full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True + plot_lower_bound = 0.0 + plot_upper_bound = 1.0 confmat: Tensor def __init__( @@ -109,3 +116,42 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute Cramer's V statistic.""" return _cramers_v_compute(self.confmat, self.bias_correction) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import CramersV + >>> metric = CramersV(num_classes=5) + >>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import CramersV + >>> metric = CramersV(num_classes=5) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/nominal/pearson.py b/src/torchmetrics/nominal/pearson.py index a95fd7724ad..5729e4df373 100644 --- a/src/torchmetrics/nominal/pearson.py +++ b/src/torchmetrics/nominal/pearson.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor @@ -23,6 +23,11 @@ ) from torchmetrics.functional.nominal.utils import _nominal_input_validation from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["PearsonsContingencyCoefficient.plot"] class PearsonsContingencyCoefficient(Metric): @@ -74,6 +79,8 @@ class PearsonsContingencyCoefficient(Metric): full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True + plot_lower_bound = 0.0 + plot_upper_bound = 1.0 confmat: Tensor def __init__( @@ -114,3 +121,42 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute Pearson's Contingency Coefficient statistic.""" return _pearsons_contingency_coefficient_compute(self.confmat) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import PearsonsContingencyCoefficient + >>> metric = PearsonsContingencyCoefficient(num_classes=5) + >>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import PearsonsContingencyCoefficient + >>> metric = PearsonsContingencyCoefficient(num_classes=5) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/nominal/theils_u.py b/src/torchmetrics/nominal/theils_u.py index 46d2ff30494..e12e4144939 100644 --- a/src/torchmetrics/nominal/theils_u.py +++ b/src/torchmetrics/nominal/theils_u.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor @@ -20,6 +20,11 @@ from torchmetrics.functional.nominal.theils_u import _theils_u_compute, _theils_u_update from torchmetrics.functional.nominal.utils import _nominal_input_validation from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["TheilsU.plot"] class TheilsU(Metric): @@ -50,13 +55,16 @@ class TheilsU(Metric): >>> _ = torch.manual_seed(42) >>> preds = torch.randint(10, (10,)) >>> target = torch.randint(10, (10,)) - >>> TheilsU(num_classes=10)(preds, target) + >>> metric = TheilsU(num_classes=10) + >>> metric(preds, target) tensor(0.8530) """ full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True + plot_lower_bound = 0.0 + plot_upper_bound = 1.0 confmat: Tensor def __init__( @@ -92,3 +100,42 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute Theil's U statistic.""" return _theils_u_compute(self.confmat) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import TheilsU + >>> metric = TheilsU(num_classes=10) + >>> metric.update(torch.randint(10, (10,)), torch.randint(10, (10,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import TheilsU + >>> metric = TheilsU(num_classes=10) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.randint(10, (10,)), torch.randint(10, (10,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/nominal/tschuprows.py b/src/torchmetrics/nominal/tschuprows.py index 7f90605a331..589e91bf738 100644 --- a/src/torchmetrics/nominal/tschuprows.py +++ b/src/torchmetrics/nominal/tschuprows.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor @@ -20,6 +20,11 @@ from torchmetrics.functional.nominal.tschuprows import _tschuprows_t_compute, _tschuprows_t_update from torchmetrics.functional.nominal.utils import _nominal_input_validation from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["TschuprowsT.plot"] class TschuprowsT(Metric): @@ -69,6 +74,8 @@ class TschuprowsT(Metric): full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True + plot_lower_bound = 0.0 + plot_upper_bound = 1.0 confmat: Tensor def __init__( @@ -109,3 +116,42 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute Tschuprow's T statistic.""" return _tschuprows_t_compute(self.confmat, self.bias_correction) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import TschuprowsT + >>> metric = TschuprowsT(num_classes=5) + >>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import TschuprowsT + >>> metric = TschuprowsT(num_classes=5) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 5ba80cab4f6..659bd9dfdfa 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -50,6 +50,7 @@ StructuralSimilarityIndexMeasure, UniversalImageQualityIndex, ) +from torchmetrics.nominal import CramersV, PearsonsContingencyCoefficient, TheilsU, TschuprowsT from torchmetrics.regression import MeanSquaredError _rand_input = lambda: torch.rand(10) @@ -59,6 +60,7 @@ _multilabel_randint_input = lambda: torch.randint(2, (10, 3)) _audio_input = lambda: torch.randn(8000) _image_input = lambda: torch.rand([8, 3, 16, 16]) +_nominal_input = lambda: torch.randint(0, 4, (100,)) @pytest.mark.parametrize( @@ -100,7 +102,17 @@ BinaryROC, _rand_input, _binary_randint_input, + id="binary roc", ), + pytest.param( + partial(PearsonsContingencyCoefficient, num_classes=5), + _nominal_input, + _nominal_input, + id="pearson contigency coef", + ), + pytest.param(partial(TheilsU, num_classes=5), _nominal_input, _nominal_input, id="theils U"), + pytest.param(partial(TschuprowsT, num_classes=5), _nominal_input, _nominal_input, id="tschuprows T"), + pytest.param(partial(CramersV, num_classes=5), _nominal_input, _nominal_input, id="cramers V"), pytest.param( SpectralDistortionIndex, _image_input,