From d42bdbab5f892aebc733905d3d94543831e6ae60 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 24 Mar 2023 11:42:18 +0100 Subject: [PATCH 01/10] impl + doctests --- src/torchmetrics/text/bert.py | 60 +++++++++++++++++++++++++++-- src/torchmetrics/text/bleu.py | 50 +++++++++++++++++++++++- src/torchmetrics/text/chrf.py | 49 +++++++++++++++++++++++ src/torchmetrics/text/infolm.py | 51 +++++++++++++++++++++++- src/torchmetrics/text/perplexity.py | 49 ++++++++++++++++++++++- src/torchmetrics/text/rouge.py | 50 +++++++++++++++++++++++- src/torchmetrics/text/sacre_bleu.py | 53 ++++++++++++++++++++++++- src/torchmetrics/text/squad.py | 50 +++++++++++++++++++++++- src/torchmetrics/text/ter.py | 49 +++++++++++++++++++++++ 9 files changed, 449 insertions(+), 12 deletions(-) diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py index b32c0e0cc75..997906f2637 100644 --- a/src/torchmetrics/text/bert.py +++ b/src/torchmetrics/text/bert.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union from warnings import warn import torch @@ -23,7 +23,11 @@ from torchmetrics.functional.text.helper_embedding_metric import _preprocess_text from torchmetrics.metric import Metric from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout -from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TRANSFORMERS_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["BERTScore.plot"] # Default model recommended in the original implementation. _DEFAULT_MODEL = "roberta-large" @@ -37,9 +41,9 @@ def _download_model() -> None: AutoModel.from_pretrained(_DEFAULT_MODEL) if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_model): - __doctest_skip__ = ["BERTScore"] + __doctest_skip__ = ["BERTScore", "BERTScore.plot"] else: - __doctest_skip__ = ["BERTScore"] + __doctest_skip__ = ["BERTScore", "BERTScore.plot"] def _get_input_dict(input_ids: List[Tensor], attention_mask: List[Tensor]) -> Dict[str, Tensor]: @@ -240,3 +244,51 @@ def compute(self) -> Dict[str, Union[List[float], str]]: baseline_path=self.baseline_path, baseline_url=self.baseline_url, ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics.text.bert import BERTScore + >>> metric = BERTScore() + >>> preds = ["hello there", "general kenobi"] + >>> target = ["hello there", "master kenobi"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics.text.bert import BERTScore + >>> metric = BERTScore() + >>> preds = ["hello there", "general kenobi"] + >>> target = ["hello there", "master kenobi"] + >>> values = [ ] + >>> for _ in range(10): + ... val = metric(preds, target) + ... val = {k: torch.tensor(v).mean() for k,v in val.items()} # convert into single value per key + ... values.append(val) + >>> fig_, ax_ = metric.plot(values) + """ + if val is None: # default average score across sentences + val = self.compute() + val = {k: torch.tensor(v).mean() for k, v in val.items()} + return self._plot(val, ax) diff --git a/src/torchmetrics/text/bleu.py b/src/torchmetrics/text/bleu.py index 462b4b347c9..f98cdfcbbaa 100644 --- a/src/torchmetrics/text/bleu.py +++ b/src/torchmetrics/text/bleu.py @@ -16,13 +16,18 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor, tensor from torchmetrics import Metric from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update, _tokenize_fn +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["BLEUScore.plot"] class BLEUScore(Metric): @@ -103,3 +108,46 @@ def compute(self) -> Tensor: return _bleu_score_compute( self.preds_len, self.target_len, self.numerator, self.denominator, self.n_gram, self.weights, self.smooth ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import BLEUScore + >>> metric = BLEUScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import BLEUScore + >>> metric = BLEUScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index 58228e93462..48061db6254 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -25,6 +25,12 @@ from torchmetrics import Metric from torchmetrics.functional.text.chrf import _chrf_score_compute, _chrf_score_update, _prepare_n_grams_dicts +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["CHRFScore.plot"] + _N_GRAM_LEVELS = ("char", "word") _TEXT_LEVELS = ("preds", "target", "matching") @@ -197,3 +203,46 @@ def _get_state_name(text: str, n_gram_level: str, n: int) -> str: def _get_text_n_gram_iterator(self) -> Iterator[Tuple[Tuple[str, int], str]]: """Get iterator over char/word and reference/hypothesis/matching n-gram level.""" return itertools.product(zip(_N_GRAM_LEVELS, [self.n_char_order, self.n_word_order]), _TEXT_LEVELS) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import CHRFScore + >>> metric = CHRFScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import CHRFScore + >>> metric = CHRFScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/infolm.py b/src/torchmetrics/text/infolm.py index e7c5abb8f69..d5e0e09adc3 100644 --- a/src/torchmetrics/text/infolm.py +++ b/src/torchmetrics/text/infolm.py @@ -28,10 +28,14 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TRANSFORMERS_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["InfoLM.plot"] if not _TRANSFORMERS_AVAILABLE: - __doctest_skip__ = ["InfoLM"] + __doctest_skip__ = ["InfoLM", "InfoLM.plot"] class InfoLM(Metric): @@ -193,3 +197,46 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: return info_lm_score.mean(), info_lm_score return info_lm_score.mean() + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics.text.infolm import InfoLM + >>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False) + >>> preds = ['he read the book because he was interested in world history'] + >>> target = ['he was interested in world history because he read the book'] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics.text.infolm import InfoLM + >>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False) + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/perplexity.py b/src/torchmetrics/text/perplexity.py index e7394450056..ed1a6bbd0af 100644 --- a/src/torchmetrics/text/perplexity.py +++ b/src/torchmetrics/text/perplexity.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Sequence, Union from torch import Tensor, tensor from torchmetrics.functional.text.perplexity import _perplexity_compute, _perplexity_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["Perplexity.plot"] class Perplexity(Metric): @@ -42,6 +47,7 @@ class Perplexity(Metric): Additional keyword arguments, see :ref:`Metric kwargs` for more info. Examples: + >>> from torchmetrics import Perplexity >>> import torch >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22)) >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22)) @@ -77,3 +83,44 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute the Perplexity.""" return _perplexity_compute(self.total_log_probs, self.count) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import Perplexity + >>> metric = Perplexity() + >>> metric.update(torch.rand(2, 8, 5), torch.randint(5, (2, 8))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import Perplexity + >>> metric = Perplexity() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(2, 8, 5), torch.randint(5, (2, 8)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/rouge.py b/src/torchmetrics/text/rouge.py index 9bb60a47df2..aa4917781cc 100644 --- a/src/torchmetrics/text/rouge.py +++ b/src/torchmetrics/text/rouge.py @@ -23,7 +23,12 @@ _rouge_score_compute, _rouge_score_update, ) -from torchmetrics.utilities.imports import _NLTK_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _NLTK_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["CharErrorRate.plot"] + __doctest_requires__ = {("ROUGEScore",): ["nltk"]} @@ -182,3 +187,46 @@ def __hash__(self) -> int: hash_vals.append(value) return hash(tuple(hash_vals)) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics.text.rouge import ROUGEScore + >>> metric = ROUGEScore() + >>> preds = "My name is John" + >>> target = "Is your name John" + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics.text.rouge import ROUGEScore + >>> metric = ROUGEScore() + >>> preds = "My name is John" + >>> target = "Is your name John" + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/sacre_bleu.py b/src/torchmetrics/text/sacre_bleu.py index 00a67fe3998..a7226387683 100644 --- a/src/torchmetrics/text/sacre_bleu.py +++ b/src/torchmetrics/text/sacre_bleu.py @@ -17,14 +17,20 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, Union +from torch import Tensor from typing_extensions import Literal from torchmetrics.functional.text.bleu import _bleu_score_update from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer from torchmetrics.text.bleu import BLEUScore -from torchmetrics.utilities.imports import _REGEX_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _REGEX_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SacreBLEUScore.plot"] + AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char") @@ -114,3 +120,46 @@ def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: self.n_gram, self.tokenizer, ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import SacreBLEUScore + >>> metric = SacreBLEUScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import SacreBLEUScore + >>> metric = SacreBLEUScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/squad.py b/src/torchmetrics/text/squad.py index ea7b0eed987..2f1e16923a3 100644 --- a/src/torchmetrics/text/squad.py +++ b/src/torchmetrics/text/squad.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Dict, Optional, Sequence, Union import torch from torch import Tensor @@ -24,6 +24,11 @@ _squad_input_check, _squad_update, ) +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["CharErrorRate.plot"] class SQuAD(Metric): @@ -113,3 +118,46 @@ def update(self, preds: PREDS_TYPE, target: TARGETS_TYPE) -> None: def compute(self) -> Dict[str, Tensor]: """Aggregate the F1 Score and Exact match for the batch.""" return _squad_compute(self.f1_score, self.exact_match, self.total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import SQuAD + >>> metric = SQuAD() + >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] + >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import SQuAD + >>> metric = SQuAD() + >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] + >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/ter.py b/src/torchmetrics/text/ter.py index 5d1ffa211fc..35cc197790d 100644 --- a/src/torchmetrics/text/ter.py +++ b/src/torchmetrics/text/ter.py @@ -19,6 +19,11 @@ from torchmetrics.functional.text.ter import _ter_compute, _ter_update, _TercomTokenizer from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["TranslationEditRate.plot"] class TranslationEditRate(Metric): @@ -46,6 +51,7 @@ class TranslationEditRate(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: + >>> from torchmetrics import TranslationEditRate >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> ter = TranslationEditRate() @@ -105,3 +111,46 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: if self.sentence_ter is not None: return ter, torch.cat(self.sentence_ter) return ter + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import TranslationEditRate + >>> metric = TranslationEditRate() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import TranslationEditRate + >>> metric = TranslationEditRate() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) From 500b891d2f1952da76383bf68cef2d3e2ae9d86b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 24 Mar 2023 11:43:14 +0100 Subject: [PATCH 02/10] requirements --- requirements/docs.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/docs.txt b/requirements/docs.txt index 1209e6d16b0..b8eaedfe27f 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -17,3 +17,4 @@ sphinx-copybutton>=0.3 -r audio.txt -r detection.txt -r image.txt +-r text.txt From f24432db3250639e4369da6a5e81634bb56ba97e Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 24 Mar 2023 11:44:23 +0100 Subject: [PATCH 03/10] tests --- tests/unittests/utilities/test_plot.py | 64 +++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index fa25c215840..95708ebeb9e 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -129,9 +129,18 @@ RetrievalRPrecision, ) from torchmetrics.text import ( + BERTScore, + BLEUScore, CharErrorRate, + CHRFScore, ExtendedEditDistance, + InfoLM, MatchErrorRate, + Perplexity, + ROUGEScore, + SacreBLEUScore, + SQuAD, + TranslationEditRate, WordErrorRate, WordInfoLost, WordInfoPreserved, @@ -151,6 +160,8 @@ _nominal_input = lambda: torch.randint(0, 4, (100,)) _text_input_1 = lambda: ["this is the prediction", "there is an other sample"] _text_input_2 = lambda: ["this is the reference", "there is another one"] +_text_input_3 = lambda: ["the cat is on the mat"] +_text_input_4 = lambda: [["there is a cat on the mat", "a cat is on the mat"]] @pytest.mark.parametrize( @@ -495,9 +506,27 @@ pytest.param(CharErrorRate, _text_input_1, _text_input_2, id="character error rate"), pytest.param(ExtendedEditDistance, _text_input_1, _text_input_2, id="extended edit distance"), pytest.param(MatchErrorRate, _text_input_1, _text_input_2, id="match error rate"), + pytest.param(BLEUScore, _text_input_3, _text_input_4, id="bleu score"), + pytest.param(CHRFScore, _text_input_3, _text_input_4, id="bleu score"), + pytest.param( + partial(InfoLM, model_name_or_path="google/bert_uncased_L-2_H-128_A-2", idf=False), + _text_input_1, + _text_input_2, + id="info lm", + ), + pytest.param(Perplexity, lambda: torch.rand(2, 8, 5), lambda: torch.randint(5, (2, 8)), id="perplexity"), + pytest.param(ROUGEScore, lambda: "My name is John", lambda: "Is your name John", id="rouge score"), + pytest.param(SacreBLEUScore, _text_input_3, _text_input_4, id="sacre bleu score"), + pytest.param( + SQuAD, + lambda: [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}], + lambda: [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}], + id="squad", + ), + pytest.param(TranslationEditRate, _text_input_3, _text_input_4, id="translation edit rate"), ], ) -@pytest.mark.parametrize("num_vals", [1, 5]) +@pytest.mark.parametrize("num_vals", [1, 3]) def test_plot_methods(metric_class: object, preds: Callable, target: Callable, num_vals: int): """Test the plot method of metrics that only output a single tensor scalar.""" metric = metric_class() @@ -575,6 +604,39 @@ def test_plot_methods_special_image_metrics(metric_class, preds, target, index_0 assert isinstance(ax, matplotlib.axes.Axes) +@pytest.mark.parametrize( + ("metric_class", "preds", "target", "transform"), + [ + pytest.param( + BERTScore, + _text_input_1, + _text_input_2, + lambda d: {k: torch.tensor(v).mean() for k, v in d.items()}, + id="bert score", + ) + ], +) +@pytest.mark.parametrize("num_vals", [1, 2]) +def test_plot_methods_special_text_metrics( + metric_class: object, preds: Callable, target: Callable, transform: Callable, num_vals: int +): + """Test the plot method for text metrics that does not fit the default testing format.""" + metric = metric_class() + + if num_vals == 1: + metric.update(preds(), target()) + fig, ax = metric.plot() + else: + vals = [] + for _ in range(num_vals): + val = metric(preds(), target()) + vals.append(transform(val)) + fig, ax = metric.plot(vals) + + assert isinstance(fig, plt.Figure) + assert isinstance(ax, matplotlib.axes.Axes) + + @pytest.mark.parametrize( ("metric_class", "preds", "target", "indexes"), [ From 57d837319294a52cb90a73d17d812d0c839eb11c Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 24 Mar 2023 11:46:23 +0100 Subject: [PATCH 04/10] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index afefa880f8e..cd5f219d546 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1623](https://github.com/Lightning-AI/metrics/pull/1623), [#1638](https://github.com/Lightning-AI/metrics/pull/1638), [#1631](https://github.com/Lightning-AI/metrics/pull/1631), + [#1650](https://github.com/Lightning-AI/metrics/pull/1650), ) From e428835499373d8387fbc6ec5267b880de1496b4 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 28 Mar 2023 16:55:40 +0200 Subject: [PATCH 05/10] fix --- src/torchmetrics/text/rouge.py | 2 +- src/torchmetrics/text/squad.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/text/rouge.py b/src/torchmetrics/text/rouge.py index aa4917781cc..99532e7588b 100644 --- a/src/torchmetrics/text/rouge.py +++ b/src/torchmetrics/text/rouge.py @@ -27,7 +27,7 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: - __doctest_skip__ = ["CharErrorRate.plot"] + __doctest_skip__ = ["ROUGEScore.plot"] __doctest_requires__ = {("ROUGEScore",): ["nltk"]} diff --git a/src/torchmetrics/text/squad.py b/src/torchmetrics/text/squad.py index 2f1e16923a3..c2aa9cf09d1 100644 --- a/src/torchmetrics/text/squad.py +++ b/src/torchmetrics/text/squad.py @@ -28,7 +28,7 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: - __doctest_skip__ = ["CharErrorRate.plot"] + __doctest_skip__ = ["SQuAD.plot"] class SQuAD(Metric): From d55d582a0c9a58fd6c52688680cde6303272829a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 28 Mar 2023 17:06:14 +0200 Subject: [PATCH 06/10] fix --- src/torchmetrics/text/bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py index 997906f2637..0e9c55d6bb6 100644 --- a/src/torchmetrics/text/bert.py +++ b/src/torchmetrics/text/bert.py @@ -291,4 +291,4 @@ def plot( if val is None: # default average score across sentences val = self.compute() val = {k: torch.tensor(v).mean() for k, v in val.items()} - return self._plot(val, ax) + return self._plot(val, ax) # type: ignore From 0ea9c0be6828cf388d5252b79a67255a7e484e42 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 28 Mar 2023 17:17:56 +0200 Subject: [PATCH 07/10] fix --- src/torchmetrics/metric.py | 6 +++++- src/torchmetrics/text/bert.py | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 0595575a647..d5276c694d0 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -581,7 +581,11 @@ def plot(self, *_: Any, **__: Any) -> Any: """Override this method plot the metric value.""" raise NotImplementedError - def _plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + def _plot( + self, + val: Optional[Union[Tensor, Sequence[Tensor], Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, + ax: Optional[_AX_TYPE] = None, + ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. Args: diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py index 0e9c55d6bb6..97ea38a8f8a 100644 --- a/src/torchmetrics/text/bert.py +++ b/src/torchmetrics/text/bert.py @@ -289,6 +289,6 @@ def plot( >>> fig_, ax_ = metric.plot(values) """ if val is None: # default average score across sentences - val = self.compute() - val = {k: torch.tensor(v).mean() for k, v in val.items()} - return self._plot(val, ax) # type: ignore + val = self.compute() # type: ignore + val = {k: torch.tensor(v).mean() for k, v in val.items()} # type: ignore + return self._plot(val, ax) From dfed9918ef23160c0183ff972f0612d855ba9a18 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 29 Mar 2023 08:28:24 +0200 Subject: [PATCH 08/10] fix --- tests/unittests/utilities/test_plot.py | 33 ++++---------------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 95708ebeb9e..627e2371b57 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -604,35 +604,12 @@ def test_plot_methods_special_image_metrics(metric_class, preds, target, index_0 assert isinstance(ax, matplotlib.axes.Axes) -@pytest.mark.parametrize( - ("metric_class", "preds", "target", "transform"), - [ - pytest.param( - BERTScore, - _text_input_1, - _text_input_2, - lambda d: {k: torch.tensor(v).mean() for k, v in d.items()}, - id="bert score", - ) - ], -) -@pytest.mark.parametrize("num_vals", [1, 2]) -def test_plot_methods_special_text_metrics( - metric_class: object, preds: Callable, target: Callable, transform: Callable, num_vals: int -): +@torch.inference_mode() +def test_plot_methods_special_text_metrics(): """Test the plot method for text metrics that does not fit the default testing format.""" - metric = metric_class() - - if num_vals == 1: - metric.update(preds(), target()) - fig, ax = metric.plot() - else: - vals = [] - for _ in range(num_vals): - val = metric(preds(), target()) - vals.append(transform(val)) - fig, ax = metric.plot(vals) - + metric = BERTScore() + metric.update(_text_input_1(), _text_input_2()) + fig, ax = metric.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) From 19103fabaf0a826ed14c67415df466f32a797579 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 29 Mar 2023 08:48:03 +0200 Subject: [PATCH 09/10] fix --- requirements/docs.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/docs.txt b/requirements/docs.txt index b8eaedfe27f..68e04b60541 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -18,3 +18,4 @@ sphinx-copybutton>=0.3 -r detection.txt -r image.txt -r text.txt +-r text_test.txt From e59c6ebae46b913b25729f61ac018dc6abd2142f Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 31 Mar 2023 08:20:30 +0200 Subject: [PATCH 10/10] inference_mode --- tests/unittests/utilities/test_plot.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index a12c9aa4264..b25e516fc8e 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -623,12 +623,13 @@ def test_plot_methods_special_image_metrics(metric_class, preds, target, index_0 assert isinstance(ax, matplotlib.axes.Axes) -@torch.inference_mode() +@pytest.mark.skipif(not hasattr(torch, "inference_mode"), reason="`inference_mode` is not supported") def test_plot_methods_special_text_metrics(): """Test the plot method for text metrics that does not fit the default testing format.""" metric = BERTScore() - metric.update(_text_input_1(), _text_input_2()) - fig, ax = metric.plot() + with torch.inference_mode(): + metric.update(_text_input_1(), _text_input_2()) + fig, ax = metric.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes)