From 9be37bcdc0dd0d3cdde175ef77c35d5e0ebe6432 Mon Sep 17 00:00:00 2001 From: Benj Fassbind Date: Sun, 21 Apr 2024 23:21:03 +0200 Subject: [PATCH 1/4] Add label names to roc curve plots --- src/torchmetrics/classification/roc.py | 6 ++++++ src/torchmetrics/utilities/plot.py | 14 +++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index 68edf6e8fcc..2b69df1dac2 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -297,6 +297,7 @@ def plot( curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, + labels: Optional[List[str]] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -307,6 +308,7 @@ def plot( will automatically compute the score. The score is computed by using the trapezoidal rule to compute the area under the curve. ax: An matplotlib axis object. If provided will add plot to that axis + labels: a list of strings, if provided will be added to the plot to indicate the different classes Returns: Figure and Axes object @@ -337,6 +339,7 @@ def plot( ax=ax, label_names=("False positive rate", "True positive rate"), name=self.__class__.__name__, + labels=labels, ) @@ -456,6 +459,7 @@ def plot( curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, + labels: Optional[List[str]] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -466,6 +470,7 @@ def plot( will automatically compute the score. The score is computed by using the trapezoidal rule to compute the area under the curve. ax: An matplotlib axis object. If provided will add plot to that axis + labels: a list of strings, if provided will be added to the plot to indicate the different classes Returns: Figure and Axes object @@ -496,6 +501,7 @@ def plot( ax=ax, label_names=("False positive rate", "True positive rate"), name=self.__class__.__name__, + labels=labels, ) diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index b1ec17597b7..785a9499018 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -274,6 +274,7 @@ def plot_curve( label_names: Optional[Tuple[str, str]] = None, legend_name: Optional[str] = None, name: Optional[str] = None, + labels: Optional[List[Union[int, str]]] = None, ) -> _PLOT_OUT_TYPE: """Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py. @@ -301,6 +302,7 @@ def plot_curve( raise ValueError("Expected 2 or 3 elements in curve but got {len(curve)}") x, y = curve[:2] + _error_on_missing_matplotlib() fig, ax = plt.subplots() if ax is None else (None, ax) @@ -312,8 +314,18 @@ def plot_curve( elif (isinstance(x, list) and isinstance(y, list)) or ( isinstance(x, Tensor) and isinstance(y, Tensor) and x.ndim == 2 and y.ndim == 2 ): + n_classes = len(x) + if labels is not None and len(labels) != n_classes: + raise ValueError( + "Expected number of elements in arg `labels` to match number of labels in roc curves but " + f"got {len(labels)} and {n_classes}" + ) + for i, (x_, y_) in enumerate(zip(x, y)): - label = f"{legend_name}_{i}" if legend_name is not None else str(i) + if labels is None: + label = f"{legend_name}_{i}" if legend_name is not None else str(i) + else: + label = labels[i] label += f" AUC={score[i].item():0.3f}" if score is not None else "" ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label) ax.legend() From 54bfa08133274c52720a2421e2d7378aa0a18087 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 21 Apr 2024 21:23:37 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/utilities/plot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 785a9499018..3ad0a89a67f 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -302,7 +302,6 @@ def plot_curve( raise ValueError("Expected 2 or 3 elements in curve but got {len(curve)}") x, y = curve[:2] - _error_on_missing_matplotlib() fig, ax = plt.subplots() if ax is None else (None, ax) From 983a307360d164343b22baabe3898dda27801a9f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 22 Apr 2024 10:05:26 +0200 Subject: [PATCH 3/4] Update src/torchmetrics/utilities/plot.py --- src/torchmetrics/utilities/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 3ad0a89a67f..26527d63256 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -324,7 +324,7 @@ def plot_curve( if labels is None: label = f"{legend_name}_{i}" if legend_name is not None else str(i) else: - label = labels[i] + label = str(labels[i]) label += f" AUC={score[i].item():0.3f}" if score is not None else "" ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label) ax.legend() From bd3baae81912b72ee76468beb66e44e3fcf2c465 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 22 Apr 2024 11:09:20 +0200 Subject: [PATCH 4/4] fix pre-commit --- src/torchmetrics/utilities/plot.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 26527d63256..c930858ef79 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -288,6 +288,7 @@ def plot_curve( label_names: Tuple containing the names of the x and y axis legend_name: Name of the curve to be used in the legend name: Custom name to describe the metric + labels: Optional labels for the different curves that will be added to the plot Returns: A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure @@ -321,10 +322,7 @@ def plot_curve( ) for i, (x_, y_) in enumerate(zip(x, y)): - if labels is None: - label = f"{legend_name}_{i}" if legend_name is not None else str(i) - else: - label = str(labels[i]) + label = f"{legend_name}_{i}" if legend_name is not None else str(i) if labels is None else str(labels[i]) label += f" AUC={score[i].item():0.3f}" if score is not None else "" ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label) ax.legend()