Add plotting 17/n #1650

Merged
merged 14 commits on Mar 31, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1623](https://github.com/Lightning-AI/metrics/pull/1623),
[#1638](https://github.com/Lightning-AI/metrics/pull/1638),
[#1631](https://github.com/Lightning-AI/metrics/pull/1631),
[#1650](https://github.com/Lightning-AI/metrics/pull/1650),
[#1639](https://github.com/Lightning-AI/metrics/pull/1639),
[#1660](https://github.com/Lightning-AI/metrics/pull/1660)
)
2 changes: 2 additions & 0 deletions requirements/docs.txt
@@ -18,3 +18,5 @@ sphinx-copybutton>=0.3
-r detection.txt
-r image.txt
-r multimodal.txt
-r text.txt
-r text_test.txt
6 changes: 5 additions & 1 deletion src/torchmetrics/metric.py
@@ -581,7 +581,11 @@ def plot(self, *_: Any, **__: Any) -> Any:
"""Override this method plot the metric value."""
raise NotImplementedError

def _plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
def _plot(
self,
val: Optional[Union[Tensor, Sequence[Tensor], Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None,
ax: Optional[_AX_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
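Note: the widened `val` annotation above means `_plot` now accepts four value shapes. A minimal illustrative sketch (the variable names and scores are made up, not part of the diff):

>>> import torch
>>> single = torch.tensor(0.75)                                    # Tensor
>>> series = [torch.tensor(0.70), torch.tensor(0.75)]              # Sequence[Tensor]
>>> keyed = {"precision": torch.tensor(0.8), "recall": torch.tensor(0.9)}  # Dict[str, Tensor]
>>> keyed_series = [keyed, {"precision": torch.tensor(0.7), "recall": torch.tensor(0.9)}]  # Sequence[Dict[str, Tensor]]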
60 changes: 56 additions & 4 deletions src/torchmetrics/text/bert.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from warnings import warn

import torch
@@ -23,7 +23,11 @@
from torchmetrics.functional.text.helper_embedding_metric import _preprocess_text
from torchmetrics.metric import Metric
from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TRANSFORMERS_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["BERTScore.plot"]

# Default model recommended in the original implementation.
_DEFAULT_MODEL = "roberta-large"
@@ -37,9 +41,9 @@ def _download_model() -> None:
AutoModel.from_pretrained(_DEFAULT_MODEL)

if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_model):
__doctest_skip__ = ["BERTScore"]
__doctest_skip__ = ["BERTScore", "BERTScore.plot"]
else:
__doctest_skip__ = ["BERTScore"]
__doctest_skip__ = ["BERTScore", "BERTScore.plot"]


def _get_input_dict(input_ids: List[Tensor], attention_mask: List[Tensor]) -> Dict[str, Tensor]:
@@ -240,3 +244,51 @@ def compute(self) -> Dict[str, Union[List[float], str]]:
baseline_path=self.baseline_path,
baseline_url=self.baseline_url,
)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute`, or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, will add the plot to that axis.

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> from torchmetrics.text.bert import BERTScore
>>> metric = BERTScore()
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.text.bert import BERTScore
>>> metric = BERTScore()
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> values = []
>>> for _ in range(10):
...     val = metric(preds, target)
...     val = {k: torch.tensor(v).mean() for k, v in val.items()}  # convert to a single value per key
...     values.append(val)
>>> fig_, ax_ = metric.plot(values)
"""
if val is None: # default average score across sentences
val = self.compute() # type: ignore
val = {k: torch.tensor(v).mean() for k, v in val.items()} # type: ignore
return self._plot(val, ax)
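Note: the `val is None` fallback above is needed because `BERTScore.compute` returns a list of per-sentence scores under each key, which must be reduced to one scalar per key before plotting. A stand-alone sketch of that reduction (the scores are made up):

>>> import torch
>>> raw = {"precision": [0.99, 0.85], "recall": [0.99, 0.83], "f1": [0.99, 0.84]}  # mimics compute() output
>>> {k: torch.tensor(v).mean() for k, v in raw.items()}
{'precision': tensor(0.9200), 'recall': tensor(0.9100), 'f1': tensor(0.9150)}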
50 changes: 49 additions & 1 deletion src/torchmetrics/text/bleu.py
@@ -16,13 +16,18 @@
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from typing import Any, Optional, Sequence
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor, tensor

from torchmetrics import Metric
from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update, _tokenize_fn
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["BLEUScore.plot"]


class BLEUScore(Metric):
@@ -103,3 +108,46 @@ def compute(self) -> Tensor:
return _bleu_score_compute(
self.preds_len, self.target_len, self.numerator, self.denominator, self.n_gram, self.weights, self.smooth
)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute`, or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, will add the plot to that axis.

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> from torchmetrics import BLEUScore
>>> metric = BLEUScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> from torchmetrics import BLEUScore
>>> metric = BLEUScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
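Note: none of the doctests exercise the `ax` argument; a hedged sketch of drawing onto a caller-provided axis (assumes matplotlib is installed):

>>> import matplotlib.pyplot as plt
>>> from torchmetrics import BLEUScore
>>> metric = BLEUScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> _ = metric(preds, target)
>>> fig, ax = plt.subplots()
>>> fig_, ax_ = metric.plot(ax=ax)  # the plot is drawn into the provided axis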
49 changes: 49 additions & 0 deletions src/torchmetrics/text/chrf.py
@@ -25,6 +25,12 @@

from torchmetrics import Metric
from torchmetrics.functional.text.chrf import _chrf_score_compute, _chrf_score_update, _prepare_n_grams_dicts
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["CHRFScore.plot"]


_N_GRAM_LEVELS = ("char", "word")
_TEXT_LEVELS = ("preds", "target", "matching")
@@ -197,3 +203,46 @@ def _get_state_name(text: str, n_gram_level: str, n: int) -> str:
def _get_text_n_gram_iterator(self) -> Iterator[Tuple[Tuple[str, int], str]]:
"""Get iterator over char/word and reference/hypothesis/matching n-gram level."""
return itertools.product(zip(_N_GRAM_LEVELS, [self.n_char_order, self.n_word_order]), _TEXT_LEVELS)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute`, or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, will add the plot to that axis.

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> from torchmetrics import CHRFScore
>>> metric = CHRFScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> from torchmetrics import CHRFScore
>>> metric = CHRFScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
51 changes: 49 additions & 2 deletions src/torchmetrics/text/infolm.py
@@ -28,10 +28,14 @@
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TRANSFORMERS_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["InfoLM.plot"]

if not _TRANSFORMERS_AVAILABLE:
__doctest_skip__ = ["InfoLM"]
__doctest_skip__ = ["InfoLM", "InfoLM.plot"]


class InfoLM(Metric):
@@ -193,3 +197,46 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]:
return info_lm_score.mean(), info_lm_score

return info_lm_score.mean()

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute`, or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, will add the plot to that axis.

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> from torchmetrics.text.infolm import InfoLM
>>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False)
>>> preds = ['he read the book because he was interested in world history']
>>> target = ['he was interested in world history because he read the book']
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> from torchmetrics.text.infolm import InfoLM
>>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False)
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
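Note: when constructed with `return_sentence_level_score=True`, `InfoLM.compute` (see the hunk above) returns a tuple of the mean score and the per-sentence scores; `plot` does not accept tuples, so only the aggregate tensor would be passed. A hedged sketch:

>>> from torchmetrics.text.infolm import InfoLM
>>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False, return_sentence_level_score=True)
>>> preds = ["this is the prediction"]
>>> target = ["this is the reference"]
>>> mean_score, sentence_scores = metric(preds, target)
>>> fig_, ax_ = metric.plot(mean_score)  # pass only the aggregate tensor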
49 changes: 48 additions & 1 deletion src/torchmetrics/text/perplexity.py
@@ -12,12 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Sequence, Union

from torch import Tensor, tensor

from torchmetrics.functional.text.perplexity import _perplexity_compute, _perplexity_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["Perplexity.plot"]


class Perplexity(Metric):
@@ -42,6 +47,7 @@ class Perplexity(Metric):
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Examples:
>>> from torchmetrics import Perplexity
>>> import torch
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
@@ -77,3 +83,44 @@ def update(self, preds: Tensor, target: Tensor) -> None:
def compute(self) -> Tensor:
"""Compute the Perplexity."""
return _perplexity_compute(self.total_log_probs, self.count)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute`, or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, will add the plot to that axis.

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import Perplexity
>>> metric = Perplexity()
>>> metric.update(torch.rand(2, 8, 5), torch.randint(5, (2, 8)))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import Perplexity
>>> metric = Perplexity()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(2, 8, 5), torch.randint(5, (2, 8))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
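Note: every module touched by this PR repeats the same matplotlib guard. A minimal sketch of the pattern in isolation (assuming the doctest runner honours `__doctest_skip__`, as pytest-doctestplus does; `MyMetric` is a hypothetical class name):

try:
    import matplotlib  # noqa: F401
    _MATPLOTLIB_AVAILABLE = True
except ImportError:
    _MATPLOTLIB_AVAILABLE = False

if not _MATPLOTLIB_AVAILABLE:
    # skip plot doctests that would fail to import matplotlib
    __doctest_skip__ = ["MyMetric.plot"]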