Skip to content

Commit

Permalink
Add label names to roc curve plots (#2511)
Browse files Browse the repository at this point in the history
* Add label names to roc curve plots

* Update src/torchmetrics/utilities/plot.py

* fix pre-commit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
  • Loading branch information
3 people authored Apr 22, 2024
1 parent cd7ccfc commit 0b33bd0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/torchmetrics/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def plot(
curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None,
score: Optional[Union[Tensor, bool]] = None,
ax: Optional[_AX_TYPE] = None,
labels: Optional[List[str]] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand All @@ -307,6 +308,7 @@ def plot(
will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
area under the curve.
ax: An matplotlib axis object. If provided will add plot to that axis
labels: a list of strings, if provided will be added to the plot to indicate the different classes
Returns:
Figure and Axes object
Expand Down Expand Up @@ -337,6 +339,7 @@ def plot(
ax=ax,
label_names=("False positive rate", "True positive rate"),
name=self.__class__.__name__,
labels=labels,
)


Expand Down Expand Up @@ -456,6 +459,7 @@ def plot(
curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None,
score: Optional[Union[Tensor, bool]] = None,
ax: Optional[_AX_TYPE] = None,
labels: Optional[List[str]] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand All @@ -466,6 +470,7 @@ def plot(
will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
area under the curve.
ax: An matplotlib axis object. If provided will add plot to that axis
labels: a list of strings, if provided will be added to the plot to indicate the different classes
Returns:
Figure and Axes object
Expand Down Expand Up @@ -496,6 +501,7 @@ def plot(
ax=ax,
label_names=("False positive rate", "True positive rate"),
name=self.__class__.__name__,
labels=labels,
)


Expand Down
11 changes: 10 additions & 1 deletion src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def plot_curve(
label_names: Optional[Tuple[str, str]] = None,
legend_name: Optional[str] = None,
name: Optional[str] = None,
labels: Optional[List[Union[int, str]]] = None,
) -> _PLOT_OUT_TYPE:
"""Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py.
Expand All @@ -287,6 +288,7 @@ def plot_curve(
label_names: Tuple containing the names of the x and y axis
legend_name: Name of the curve to be used in the legend
name: Custom name to describe the metric
labels: Optional labels for the different curves that will be added to the plot
Returns:
A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure
Expand All @@ -312,8 +314,15 @@ def plot_curve(
elif (isinstance(x, list) and isinstance(y, list)) or (
isinstance(x, Tensor) and isinstance(y, Tensor) and x.ndim == 2 and y.ndim == 2
):
n_classes = len(x)
if labels is not None and len(labels) != n_classes:
raise ValueError(
"Expected number of elements in arg `labels` to match number of labels in roc curves but "
f"got {len(labels)} and {n_classes}"
)

for i, (x_, y_) in enumerate(zip(x, y)):
label = f"{legend_name}_{i}" if legend_name is not None else str(i)
label = f"{legend_name}_{i}" if legend_name is not None else str(i) if labels is None else str(labels[i])
label += f" AUC={score[i].item():0.3f}" if score is not None else ""
ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label)
ax.legend()
Expand Down

0 comments on commit 0b33bd0

Please sign in to comment.