diff --git a/CHANGELOG.md b/CHANGELOG.md index 0509f91fd72..ea434702767 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1623](https://github.com/Lightning-AI/metrics/pull/1623), [#1638](https://github.com/Lightning-AI/metrics/pull/1638), [#1631](https://github.com/Lightning-AI/metrics/pull/1631), + [#1650](https://github.com/Lightning-AI/metrics/pull/1650), [#1639](https://github.com/Lightning-AI/metrics/pull/1639), [#1660](https://github.com/Lightning-AI/metrics/pull/1660) ) diff --git a/requirements/docs.txt b/requirements/docs.txt index 5131098333f..b6141996dfd 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -18,3 +18,5 @@ sphinx-copybutton>=0.3 -r detection.txt -r image.txt -r multimodal.txt +-r text.txt +-r text_test.txt diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 0595575a647..d5276c694d0 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -581,7 +581,11 @@ def plot(self, *_: Any, **__: Any) -> Any: """Override this method plot the metric value.""" raise NotImplementedError - def _plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + def _plot( + self, + val: Optional[Union[Tensor, Sequence[Tensor], Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, + ax: Optional[_AX_TYPE] = None, + ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. Args: diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py index b32c0e0cc75..97ea38a8f8a 100644 --- a/src/torchmetrics/text/bert.py +++ b/src/torchmetrics/text/bert.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union from warnings import warn import torch @@ -23,7 +23,11 @@ from torchmetrics.functional.text.helper_embedding_metric import _preprocess_text from torchmetrics.metric import Metric from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout -from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TRANSFORMERS_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["BERTScore.plot"] # Default model recommended in the original implementation. _DEFAULT_MODEL = "roberta-large" @@ -37,9 +41,9 @@ def _download_model() -> None: AutoModel.from_pretrained(_DEFAULT_MODEL) if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_model): - __doctest_skip__ = ["BERTScore"] + __doctest_skip__ = ["BERTScore", "BERTScore.plot"] else: - __doctest_skip__ = ["BERTScore"] + __doctest_skip__ = ["BERTScore", "BERTScore.plot"] def _get_input_dict(input_ids: List[Tensor], attention_mask: List[Tensor]) -> Dict[str, Tensor]: @@ -240,3 +244,51 @@ def compute(self) -> Dict[str, Union[List[float], str]]: baseline_path=self.baseline_path, baseline_url=self.baseline_url, ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics.text.bert import BERTScore + >>> metric = BERTScore() + >>> preds = ["hello there", "general kenobi"] + >>> target = ["hello there", "master kenobi"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics.text.bert import BERTScore + >>> metric = BERTScore() + >>> preds = ["hello there", "general kenobi"] + >>> target = ["hello there", "master kenobi"] + >>> values = [ ] + >>> for _ in range(10): + ... val = metric(preds, target) + ... val = {k: torch.tensor(v).mean() for k,v in val.items()} # convert into single value per key + ... values.append(val) + >>> fig_, ax_ = metric.plot(values) + """ + if val is None: # default average score across sentences + val = self.compute() # type: ignore + val = {k: torch.tensor(v).mean() for k, v in val.items()} # type: ignore + return self._plot(val, ax) diff --git a/src/torchmetrics/text/bleu.py b/src/torchmetrics/text/bleu.py index 462b4b347c9..f98cdfcbbaa 100644 --- a/src/torchmetrics/text/bleu.py +++ b/src/torchmetrics/text/bleu.py @@ -16,13 +16,18 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor, tensor from torchmetrics import Metric from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update, _tokenize_fn +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["BLEUScore.plot"] class BLEUScore(Metric): @@ -103,3 +108,46 @@ def compute(self) -> Tensor: return _bleu_score_compute( self.preds_len, self.target_len, self.numerator, self.denominator, self.n_gram, self.weights, self.smooth ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import BLEUScore + >>> metric = BLEUScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import BLEUScore + >>> metric = BLEUScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index 58228e93462..48061db6254 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -25,6 +25,12 @@ from torchmetrics import Metric from torchmetrics.functional.text.chrf import _chrf_score_compute, _chrf_score_update, _prepare_n_grams_dicts +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["CHRFScore.plot"] + _N_GRAM_LEVELS = ("char", "word") _TEXT_LEVELS = ("preds", "target", "matching") @@ -197,3 +203,46 @@ def _get_state_name(text: str, n_gram_level: str, n: int) -> str: def _get_text_n_gram_iterator(self) -> Iterator[Tuple[Tuple[str, int], str]]: """Get iterator over char/word and reference/hypothesis/matching n-gram level.""" return itertools.product(zip(_N_GRAM_LEVELS, [self.n_char_order, self.n_word_order]), _TEXT_LEVELS) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import CHRFScore + >>> metric = CHRFScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import CHRFScore + >>> metric = CHRFScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/infolm.py b/src/torchmetrics/text/infolm.py index e7c5abb8f69..d5e0e09adc3 100644 --- a/src/torchmetrics/text/infolm.py +++ b/src/torchmetrics/text/infolm.py @@ -28,10 +28,14 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TRANSFORMERS_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["InfoLM.plot"] if not _TRANSFORMERS_AVAILABLE: - __doctest_skip__ = ["InfoLM"] + __doctest_skip__ = ["InfoLM", "InfoLM.plot"] class InfoLM(Metric): @@ -193,3 +197,46 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: return info_lm_score.mean(), info_lm_score return info_lm_score.mean() + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics.text.infolm import InfoLM + >>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False) + >>> preds = ['he read the book because he was interested in world history'] + >>> target = ['he was interested in world history because he read the book'] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics.text.infolm import InfoLM + >>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False) + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/perplexity.py b/src/torchmetrics/text/perplexity.py index e7394450056..ed1a6bbd0af 100644 --- a/src/torchmetrics/text/perplexity.py +++ b/src/torchmetrics/text/perplexity.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Sequence, Union from torch import Tensor, tensor from torchmetrics.functional.text.perplexity import _perplexity_compute, _perplexity_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["Perplexity.plot"] class Perplexity(Metric): @@ -42,6 +47,7 @@ class Perplexity(Metric): Additional keyword arguments, see :ref:`Metric kwargs` for more info. Examples: + >>> from torchmetrics import Perplexity >>> import torch >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22)) >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22)) @@ -77,3 +83,44 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute the Perplexity.""" return _perplexity_compute(self.total_log_probs, self.count) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import Perplexity + >>> metric = Perplexity() + >>> metric.update(torch.rand(2, 8, 5), torch.randint(5, (2, 8))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import Perplexity + >>> metric = Perplexity() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(2, 8, 5), torch.randint(5, (2, 8)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/rouge.py b/src/torchmetrics/text/rouge.py index 9bb60a47df2..99532e7588b 100644 --- a/src/torchmetrics/text/rouge.py +++ b/src/torchmetrics/text/rouge.py @@ -23,7 +23,12 @@ _rouge_score_compute, _rouge_score_update, ) -from torchmetrics.utilities.imports import _NLTK_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _NLTK_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["ROUGEScore.plot"] + __doctest_requires__ = {("ROUGEScore",): ["nltk"]} @@ -182,3 +187,46 @@ def __hash__(self) -> int: hash_vals.append(value) return hash(tuple(hash_vals)) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics.text.rouge import ROUGEScore + >>> metric = ROUGEScore() + >>> preds = "My name is John" + >>> target = "Is your name John" + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics.text.rouge import ROUGEScore + >>> metric = ROUGEScore() + >>> preds = "My name is John" + >>> target = "Is your name John" + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/sacre_bleu.py b/src/torchmetrics/text/sacre_bleu.py index 00a67fe3998..a7226387683 100644 --- a/src/torchmetrics/text/sacre_bleu.py +++ b/src/torchmetrics/text/sacre_bleu.py @@ -17,14 +17,20 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, Union +from torch import Tensor from typing_extensions import Literal from torchmetrics.functional.text.bleu import _bleu_score_update from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer from torchmetrics.text.bleu import BLEUScore -from torchmetrics.utilities.imports import _REGEX_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _REGEX_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SacreBLEUScore.plot"] + AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char") @@ -114,3 +120,46 @@ def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: self.n_gram, self.tokenizer, ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import SacreBLEUScore + >>> metric = SacreBLEUScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import SacreBLEUScore + >>> metric = SacreBLEUScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/squad.py b/src/torchmetrics/text/squad.py index ea7b0eed987..c2aa9cf09d1 100644 --- a/src/torchmetrics/text/squad.py +++ b/src/torchmetrics/text/squad.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Dict, Optional, Sequence, Union import torch from torch import Tensor @@ -24,6 +24,11 @@ _squad_input_check, _squad_update, ) +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SQuAD.plot"] class SQuAD(Metric): @@ -113,3 +118,46 @@ def update(self, preds: PREDS_TYPE, target: TARGETS_TYPE) -> None: def compute(self) -> Dict[str, Tensor]: """Aggregate the F1 Score and Exact match for the batch.""" return _squad_compute(self.f1_score, self.exact_match, self.total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import SQuAD + >>> metric = SQuAD() + >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] + >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import SQuAD + >>> metric = SQuAD() + >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] + >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/ter.py b/src/torchmetrics/text/ter.py index 5d1ffa211fc..35cc197790d 100644 --- a/src/torchmetrics/text/ter.py +++ b/src/torchmetrics/text/ter.py @@ -19,6 +19,11 @@ from torchmetrics.functional.text.ter import _ter_compute, _ter_update, _TercomTokenizer from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["TranslationEditRate.plot"] class TranslationEditRate(Metric): @@ -46,6 +51,7 @@ class TranslationEditRate(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: + >>> from torchmetrics import TranslationEditRate >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> ter = TranslationEditRate() @@ -105,3 +111,46 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: if self.sentence_ter is not None: return ter, torch.cat(self.sentence_ter) return ter + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import TranslationEditRate + >>> metric = TranslationEditRate() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import TranslationEditRate + >>> metric = TranslationEditRate() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 9e7422edc3a..55940fe85d5 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -141,9 +141,18 @@ RetrievalRPrecision, ) from torchmetrics.text import ( + BERTScore, + BLEUScore, CharErrorRate, + CHRFScore, ExtendedEditDistance, + InfoLM, MatchErrorRate, + Perplexity, + ROUGEScore, + SacreBLEUScore, + SQuAD, + TranslationEditRate, WordErrorRate, WordInfoLost, WordInfoPreserved, @@ -164,6 +173,8 @@ _nominal_input = lambda: torch.randint(0, 4, (100,)) _text_input_1 = lambda: ["this is the prediction", "there is an other sample"] _text_input_2 = lambda: ["this is the reference", "there is another one"] +_text_input_3 = lambda: ["the cat is on the mat"] +_text_input_4 = lambda: [["there is a cat on the mat", "a cat is on the mat"]] @pytest.mark.parametrize( @@ -546,9 +557,27 @@ pytest.param(CharErrorRate, _text_input_1, _text_input_2, id="character error rate"), pytest.param(ExtendedEditDistance, _text_input_1, _text_input_2, id="extended edit distance"), pytest.param(MatchErrorRate, _text_input_1, _text_input_2, id="match error rate"), + pytest.param(BLEUScore, _text_input_3, _text_input_4, id="bleu score"), + pytest.param(CHRFScore, _text_input_3, _text_input_4, id="bleu score"), + pytest.param( + partial(InfoLM, model_name_or_path="google/bert_uncased_L-2_H-128_A-2", idf=False), + _text_input_1, + _text_input_2, + id="info lm", + ), + pytest.param(Perplexity, lambda: torch.rand(2, 8, 5), lambda: torch.randint(5, (2, 8)), id="perplexity"), + pytest.param(ROUGEScore, lambda: "My name is John", lambda: "Is your name John", id="rouge score"), + pytest.param(SacreBLEUScore, _text_input_3, _text_input_4, id="sacre bleu score"), + pytest.param( + SQuAD, + lambda: [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}], + lambda: [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}], + id="squad", + ), + pytest.param(TranslationEditRate, _text_input_3, _text_input_4, id="translation edit rate"), ], ) -@pytest.mark.parametrize("num_vals", [1, 5]) +@pytest.mark.parametrize("num_vals", [1, 3]) def test_plot_methods(metric_class: object, preds: Callable, target: Callable, num_vals: int): """Test the plot method of metrics that only output a single tensor scalar.""" metric = metric_class() @@ -626,6 +655,17 @@ def test_plot_methods_special_image_metrics(metric_class, preds, target, index_0 assert isinstance(ax, matplotlib.axes.Axes) +@pytest.mark.skipif(not hasattr(torch, "inference_mode"), reason="`inference_mode` is not supported") +def test_plot_methods_special_text_metrics(): + """Test the plot method for text metrics that does not fit the default testing format.""" + metric = BERTScore() + with torch.inference_mode(): + metric.update(_text_input_1(), _text_input_2()) + fig, ax = metric.plot() + assert isinstance(fig, plt.Figure) + assert isinstance(ax, matplotlib.axes.Axes) + + @pytest.mark.parametrize( ("metric_class", "preds", "target", "indexes"), [