Merge branch 'master' into ci/gpu-pt2.0
Borda authored Mar 31, 2023
2 parents fdc54a6 + 2c90745 commit 1ebdb07
Showing 10 changed files with 735 additions and 80 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1638](https://github.com/Lightning-AI/metrics/pull/1638),
[#1631](https://github.com/Lightning-AI/metrics/pull/1631),
[#1639](https://github.com/Lightning-AI/metrics/pull/1639),
[#1660](https://github.com/Lightning-AI/metrics/pull/1660)
)
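
The pull requests listed above are part of the ongoing work, visible in the rest of this commit, to add a `.plot()` method to more metrics. A minimal usage sketch, assuming matplotlib is installed (the output file name is illustrative):

import torch
from torchmetrics.classification import BinaryFBetaScore

metric = BinaryFBetaScore(beta=2.0)
for _ in range(5):  # accumulate a few batches of predictions and targets
    metric.update(torch.rand(10), torch.randint(2, (10,)))

fig, ax = metric.plot()  # with no argument, plots the value from metric.compute()
fig.savefig("binary_fbeta.png")  # illustrative output path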


28 changes: 28 additions & 0 deletions examples/plotting.py
@@ -393,9 +393,37 @@ def mean_average_precision() -> tuple:
    return fig, ax


def roc_example() -> tuple:
    """Plot ROC metric."""
    from torchmetrics.classification import BinaryROC, MulticlassROC, MultilabelROC

    # binary case
    p = lambda: torch.rand(20)
    t = lambda: torch.randint(2, (20,))

    metric = BinaryROC()
    metric.update(p(), t())
    fig, ax = metric.plot()

    # multiclass case
    p = lambda: torch.randn(200, 5)
    t = lambda: torch.randint(5, (200,))

    metric = MulticlassROC(5)
    metric.update(p(), t())
    fig, ax = metric.plot()

    # multilabel case
    p = lambda: torch.rand(20, 2)
    t = lambda: torch.randint(2, (20, 2))

    metric = MultilabelROC(2)
    metric.update(p(), t())
    fig, ax = metric.plot()

    return fig, ax
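
The objects returned by these example functions are ordinary matplotlib figures and axes, so they can be tweaked, saved, or shown directly; a short sketch, assuming matplotlib is installed (the file name is illustrative):

import matplotlib.pyplot as plt

fig, ax = roc_example()                  # matplotlib Figure and Axes from the last plot
ax.set_title("MultilabelROC example")    # optional cosmetic tweak
fig.savefig("roc_example.png", dpi=150)  # illustrative output path
plt.show()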


if __name__ == "__main__":
    metrics_func = {
        "accuracy": accuracy_example,
        "roc": roc_example,
        "pesq": pesq_example,
        "pit": pit_example,
        "sdr": sdr_example,
260 changes: 259 additions & 1 deletion src/torchmetrics/classification/f_beta.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal
@@ -25,6 +25,18 @@
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = [
        "BinaryFBetaScore.plot",
        "MulticlassFBetaScore.plot",
        "MultilabelFBetaScore.plot",
        "BinaryF1Score.plot",
        "MulticlassF1Score.plot",
        "MultilabelF1Score.plot",
    ]
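
For readers unfamiliar with this idiom: an availability flag guards the optional matplotlib dependency, and `__doctest_skip__` tells a compatible doctest runner (e.g. pytest-doctestplus or a project-specific collector) which examples to skip when the dependency is missing. A self-contained sketch of the general pattern, not of torchmetrics internals:

import importlib.util

# Availability flag for an optional plotting backend (illustrative, not the library's code).
_PLOTTING_AVAILABLE = importlib.util.find_spec("matplotlib") is not None


def plot_histogram(values):
    """Render a histogram of ``values``.

    >>> fig, ax = plot_histogram([1, 2, 2, 3])  # requires matplotlib
    """
    if not _PLOTTING_AVAILABLE:
        raise ModuleNotFoundError("matplotlib is required for plotting")
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.hist(values)
    return fig, ax


if not _PLOTTING_AVAILABLE:
    # Runners that honour this module-level hook skip the listed doctests.
    __doctest_skip__ = ["plot_histogram"]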


class BinaryFBetaScore(BinaryStatScores):
@@ -124,6 +136,47 @@ def compute(self) -> Tensor:
        tp, fp, tn, fn = self._final_state()
        return _fbeta_reduce(tp, fp, tn, fn, self.beta, average="binary", multidim_average=self.multidim_average)
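
As a reminder of what this reduction computes: in the binary case, F-beta can be written directly in terms of the confusion-matrix counts. A small worked sketch (not the library's implementation):

import torch
from torch import Tensor


def binary_fbeta_from_counts(tp: Tensor, fp: Tensor, fn: Tensor, beta: float) -> Tensor:
    """F_beta = (1 + beta^2) * tp / ((1 + beta^2) * tp + beta^2 * fn + fp)."""
    beta2 = beta**2
    return (1 + beta2) * tp / ((1 + beta2) * tp + beta2 * fn + fp)


# Example: 8 true positives, 2 false positives, 4 false negatives, beta = 2
score = binary_fbeta_from_counts(torch.tensor(8.0), torch.tensor(2.0), torch.tensor(4.0), beta=2.0)
print(score)  # tensor(0.6897): with beta > 1, recall is weighted more heavily than precision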

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, will add the plot to that axis.

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import rand, randint
            >>> # Example plotting a single value
            >>> from torchmetrics.classification import BinaryFBetaScore
            >>> metric = BinaryFBetaScore(beta=2.0)
            >>> metric.update(rand(10), randint(2, (10,)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import rand, randint
            >>> # Example plotting multiple values
            >>> from torchmetrics.classification import BinaryFBetaScore
            >>> metric = BinaryFBetaScore(beta=2.0)
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(rand(10), randint(2, (10,))))
            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)
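
Beyond the doctests above, the `val` and `ax` arguments let you plot a history of values and draw several metrics into one set of axes. A minimal sketch, assuming matplotlib is installed (variable names and output path are illustrative):

import matplotlib.pyplot as plt
from torch import rand, randint
from torchmetrics.classification import BinaryF1Score, BinaryFBetaScore

f2 = BinaryFBetaScore(beta=2.0)
f1 = BinaryF1Score()

# Collect one value per "step" by calling the metrics (forward) on each batch.
f2_history, f1_history = [], []
for _ in range(10):
    preds, target = rand(10), randint(2, (10,))
    f2_history.append(f2(preds, target))
    f1_history.append(f1(preds, target))

# Reuse a single axis: pass it via ``ax`` so both curves land in the same plot.
fig, ax = plt.subplots()
f2.plot(f2_history, ax=ax)
f1.plot(f1_history, ax=ax)
fig.savefig("fbeta_vs_f1.png")  # illustrative output path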


class MulticlassFBetaScore(MulticlassStatScores):
    r"""Compute `F-score`_ metric for multiclass tasks.
@@ -258,6 +311,47 @@ def compute(self) -> Tensor:
        tp, fp, tn, fn = self._final_state()
        return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, will add the plot to that axis.

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randint
            >>> # Example plotting a single value per class
            >>> from torchmetrics.classification import MulticlassFBetaScore
            >>> metric = MulticlassFBetaScore(num_classes=3, beta=2.0, average=None)
            >>> metric.update(randint(3, (20,)), randint(3, (20,)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randint
            >>> # Example plotting multiple values per class
            >>> from torchmetrics.classification import MulticlassFBetaScore
            >>> metric = MulticlassFBetaScore(num_classes=3, beta=2.0, average=None)
            >>> values = []
            >>> for _ in range(20):
            ...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)


class MultilabelFBetaScore(MultilabelStatScores):
    r"""Compute `F-score`_ metric for multilabel tasks.
@@ -388,6 +482,47 @@ def compute(self) -> Tensor:
        tp, fp, tn, fn = self._final_state()
        return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, will add the plot to that axis.

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randint
            >>> # Example plotting a single value
            >>> from torchmetrics.classification import MultilabelFBetaScore
            >>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0)
            >>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randint
            >>> # Example plotting multiple values
            >>> from torchmetrics.classification import MultilabelFBetaScore
            >>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0)
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)


class BinaryF1Score(BinaryFBetaScore):
    r"""Compute F-1 score for binary tasks.
@@ -475,6 +610,47 @@ def __init__(
            **kwargs,
        )

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, will add the plot to that axis.

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import rand, randint
            >>> # Example plotting a single value
            >>> from torchmetrics.classification import BinaryF1Score
            >>> metric = BinaryF1Score()
            >>> metric.update(rand(10), randint(2, (10,)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import rand, randint
            >>> # Example plotting multiple values
            >>> from torchmetrics.classification import BinaryF1Score
            >>> metric = BinaryF1Score()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(rand(10), randint(2, (10,))))
            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)


class MulticlassF1Score(MulticlassFBetaScore):
    r"""Compute F-1 score for multiclass tasks.
@@ -598,6 +774,47 @@ def __init__(
            **kwargs,
        )

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, will add the plot to that axis.

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randint
            >>> # Example plotting a single value per class
            >>> from torchmetrics.classification import MulticlassF1Score
            >>> metric = MulticlassF1Score(num_classes=3, average=None)
            >>> metric.update(randint(3, (20,)), randint(3, (20,)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randint
            >>> # Example plotting multiple values per class
            >>> from torchmetrics.classification import MulticlassF1Score
            >>> metric = MulticlassF1Score(num_classes=3, average=None)
            >>> values = []
            >>> for _ in range(20):
            ...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)


class MultilabelF1Score(MultilabelFBetaScore):
    r"""Compute F-1 score for multilabel tasks.
@@ -717,6 +934,47 @@ def __init__(
            **kwargs,
        )

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, will add the plot to that axis.

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randint
            >>> # Example plotting a single value
            >>> from torchmetrics.classification import MultilabelF1Score
            >>> metric = MultilabelF1Score(num_labels=3)
            >>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randint
            >>> # Example plotting multiple values
            >>> from torchmetrics.classification import MultilabelF1Score
            >>> metric = MultilabelF1Score(num_labels=3)
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)


class FBetaScore:
    r"""Compute `F-score`_ metric.
