diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5ae89562ec1..cb71dce38fb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,10 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
-- Added global option `sync_on_compute` to disable automatic syncronization when `compute` is called ([#1107](https://github.dev/Lightning-AI/metrics/pull/1107))
+- Added `Perplexity` metric ([#922](https://github.com/PyTorchLightning/metrics/pull/922))
 
--
+- Added global option `sync_on_compute` to disable automatic synchronization when `compute` is called ([#1107](https://github.com/Lightning-AI/metrics/pull/1107))
 
 
 ### Changed
diff --git a/docs/source/text/perplexity.rst b/docs/source/text/perplexity.rst
new file mode 100644
index 00000000000..966bf00e560
--- /dev/null
+++ b/docs/source/text/perplexity.rst
@@ -0,0 +1,22 @@
+.. customcarditem::
+    :header: Perplexity
+    :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/summarization.svg
+    :tags: Text
+
+.. include:: ../links.rst
+
+##########
+Perplexity
+##########
+
+Module Interface
+________________
+
+.. autoclass:: torchmetrics.text.perplexity.Perplexity
+    :noindex:
+
+Functional Interface
+____________________
+
+.. autofunction:: torchmetrics.functional.text.perplexity.perplexity
+    :noindex:
diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py
index 549367ce4da..483a4d786b7 100644
--- a/src/torchmetrics/__init__.py
+++ b/src/torchmetrics/__init__.py
@@ -92,6 +92,7 @@
     CHRFScore,
     ExtendedEditDistance,
     MatchErrorRate,
+    Perplexity,
     SacreBLEUScore,
     SQuAD,
     TranslationEditRate,
@@ -157,6 +158,7 @@
     "MultiScaleStructuralSimilarityIndexMeasure",
     "PearsonCorrCoef",
     "PermutationInvariantTraining",
+    "Perplexity",
     "Precision",
     "PrecisionRecallCurve",
     "PeakSignalNoiseRatio",
diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py
index 08f0ac2e46c..3782433bb2e 100644
--- a/src/torchmetrics/functional/__init__.py
+++ b/src/torchmetrics/functional/__init__.py
@@ -78,6 +78,7 @@
 from torchmetrics.functional.text.chrf import chrf_score
 from torchmetrics.functional.text.eed import extended_edit_distance
 from torchmetrics.functional.text.mer import match_error_rate
+from torchmetrics.functional.text.perplexity import perplexity
 from torchmetrics.functional.text.rouge import rouge_score
 from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
 from torchmetrics.functional.text.squad import squad
@@ -131,6 +132,7 @@
     "pairwise_manhattan_distance",
     "pearson_corrcoef",
     "permutation_invariant_training",
+    "perplexity",
     "pit_permutate",
     "precision",
     "precision_recall",
diff --git a/src/torchmetrics/functional/text/__init__.py b/src/torchmetrics/functional/text/__init__.py
index 14b982f90eb..a03a0a099f3 100644
--- a/src/torchmetrics/functional/text/__init__.py
+++ b/src/torchmetrics/functional/text/__init__.py
@@ -17,6 +17,7 @@
 from torchmetrics.functional.text.chrf import chrf_score  # noqa: F401
 from torchmetrics.functional.text.eed import extended_edit_distance  # noqa: F401
 from torchmetrics.functional.text.mer import match_error_rate  # noqa: F401
+from torchmetrics.functional.text.perplexity import perplexity  # noqa: F401
 from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score  # noqa: F401
 from torchmetrics.functional.text.squad import squad  # noqa: F401
 from torchmetrics.functional.text.ter import translation_edit_rate  # noqa: F401
diff --git a/src/torchmetrics/functional/text/perplexity.py b/src/torchmetrics/functional/text/perplexity.py
new file mode 100644
index 00000000000..8e614580761
--- /dev/null
+++ b/src/torchmetrics/functional/text/perplexity.py
@@ -0,0 +1,139 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+_TORCH_FLOAT_OR_DOUBLE = (torch.float32, torch.float64)
+
+
+def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> None:
+    """Check shape and type consistency of input vectors.
+
+    Args:
+        preds:
+            Logits or unnormalized scores for each token in a sequence, shape [batch_size, seq_len, vocab_size].
+        target:
+            Ground truth values with a shape [batch_size, seq_len].
+
+    Raises:
+        ValueError:
+            If ``preds`` tensor does not have 3 dimensions.
+        ValueError:
+            If ``target`` tensor does not have 2 dimensions.
+        ValueError:
+            If the first two dimensions of ``preds`` and ``target`` are not equal.
+        TypeError:
+            If ``preds`` dtype is not one of ``(torch.float32, torch.float64)``
+        TypeError:
+            If ``target`` is not of type ``torch.int64`` (LongTensor)
+    """
+    if len(preds.shape) != 3:
+        raise ValueError(
+            "Input tensor `preds` is expected to have 3 dimensions, [batch_size, seq_len, vocab_size],"
+            f" but got {len(preds.shape)}."
+        )
+    if len(target.shape) != 2:
+        raise ValueError(
+            "Input tensor `target` is expected to have 2 dimensions, [batch_size, seq_len],"
+            f" but got {len(target.shape)}."
+        )
+    if preds.shape[:2] != target.shape:
+        raise ValueError(
+            "Input tensors `preds` and `target` are expected to have equal first two dimensions,"
+            f" [batch_size, seq_len], but got {preds.shape[:2]} and {target.shape}."
+        )
+    if preds.dtype not in _TORCH_FLOAT_OR_DOUBLE:
+        raise TypeError(
+            f"Input tensor `preds` is expected to be of one of the types {_TORCH_FLOAT_OR_DOUBLE} but got {preds.dtype}."
+        )
+    if target.dtype != torch.int64:
+        raise TypeError(f"Input tensor `target` is expected to be of type {torch.int64} but got {target.dtype}.")
+
+
+def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tuple[Tensor, Tensor]:
+    """Compute intermediate statistics for Perplexity.
+
+    Args:
+        preds:
+            Logits or unnormalized scores for each token in a sequence, shape [batch_size, seq_len, vocab_size].
+        target:
+            Ground truth values with a shape [batch_size, seq_len].
+        ignore_index:
+            Integer specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score.
+
+    Returns:
+        Negative log probabilities, summed over all samples
+        Number of samples
+    """
+    _check_shape_and_type_consistency(preds, target)
+
+    probs = F.softmax(preds.reshape(-1, preds.shape[-1]), dim=1)
+    target = target.reshape(-1)
+
+    if ignore_index is not None:
+        mask = target.ne(ignore_index)
+        target = target.where(target != ignore_index, torch.tensor(0, device=target.device))
+    else:
+        mask = torch.ones_like(target, dtype=torch.bool)
+
+    probs = probs[:, target].diagonal()[mask]
+    total_log_probs = -probs.log().sum()
+    count = mask.sum()
+
+    return total_log_probs, count
+
+
+def _perplexity_compute(total: Tensor, count: Tensor) -> Tensor:
+    """Compute the Perplexity.
+
+    Args:
+        total: Negative log probabilities, summed over all samples
+        count: Number of samples
+    Returns:
+        Perplexity
+    """
+    return torch.exp(total / count)
+
+
+def perplexity(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tensor:
+    """Perplexity measures how well a language model predicts a text sample. It is calculated as the exponential
+    of the average negative log-likelihood per token.
+
+    Args:
+        preds:
+            Logits or unnormalized scores for each token in a sequence, shape [batch_size, seq_len, vocab_size].
+        target:
+            Ground truth values with a shape [batch_size, seq_len].
+        ignore_index:
+            Integer specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score.
+
+    Returns:
+        Perplexity value
+
+    Examples:
+        >>> import torch
+        >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
+        >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
+        >>> target[0, 6:] = -100
+        >>> perplexity(preds, target, ignore_index=-100)
+        tensor(5.2545)
+    """
+    total, count = _perplexity_update(preds, target, ignore_index)
+    return _perplexity_compute(total, count)
diff --git a/src/torchmetrics/text/__init__.py b/src/torchmetrics/text/__init__.py
index 27e95d6ab6b..6b05fcd5c2c 100644
--- a/src/torchmetrics/text/__init__.py
+++ b/src/torchmetrics/text/__init__.py
@@ -16,6 +16,7 @@
 from torchmetrics.text.chrf import CHRFScore  # noqa: F401
 from torchmetrics.text.eed import ExtendedEditDistance  # noqa: F401
 from torchmetrics.text.mer import MatchErrorRate  # noqa: F401
+from torchmetrics.text.perplexity import Perplexity  # noqa: F401
 from torchmetrics.text.sacre_bleu import SacreBLEUScore  # noqa: F401
 from torchmetrics.text.squad import SQuAD  # noqa: F401
 from torchmetrics.text.ter import TranslationEditRate  # noqa: F401
diff --git a/src/torchmetrics/text/perplexity.py b/src/torchmetrics/text/perplexity.py
new file mode 100644
index 00000000000..e0c2efa4c17
--- /dev/null
+++ b/src/torchmetrics/text/perplexity.py
@@ -0,0 +1,81 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+from torch import Tensor, tensor
+
+from torchmetrics.functional.text.perplexity import _perplexity_compute, _perplexity_update
+from torchmetrics.metric import Metric
+
+
+class Perplexity(Metric):
+    r"""
+    Perplexity measures how well a language model predicts a text sample. It is calculated as the exponential
+    of the average negative log-likelihood per token.
+
+    Args:
+        ignore_index:
+            Integer specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score.
+        kwargs:
+            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Examples:
+        >>> import torch
+        >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
+        >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
+        >>> target[0, 6:] = -100
+        >>> metric = Perplexity(ignore_index=-100)
+        >>> metric(preds, target)
+        tensor(5.2545)
+    """
+    is_differentiable = True
+    higher_is_better = False
+    full_state_update = False
+    total_log_probs: Tensor
+    count: Tensor
+
+    def __init__(
+        self,
+        ignore_index: Optional[int] = None,
+        **kwargs: Dict[str, Any],
+    ):
+        super().__init__(**kwargs)
+        if ignore_index is not None and not isinstance(ignore_index, int):
+            raise ValueError(f"Argument `ignore_index` expected to either be `None` or an `int` but got {ignore_index}")
+        self.ignore_index = ignore_index
+        self.add_state("total_log_probs", default=tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("count", default=tensor(0.0), dist_reduce_fx="sum")
+
+    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
+        """Compute and store intermediate statistics for Perplexity.
+
+        Args:
+            preds:
+                Logits or unnormalized scores for each token in a sequence, shape [batch_size, seq_len, vocab_size].
+            target:
+                Ground truth values with a shape [batch_size, seq_len].
+        """
+        total_log_probs, count = _perplexity_update(preds, target, self.ignore_index)
+        self.total_log_probs += total_log_probs
+        self.count += count
+
+    def compute(self) -> Tensor:
+        """Compute the Perplexity.
+
+        Returns:
+            Perplexity
+        """
+        return _perplexity_compute(self.total_log_probs, self.count)
diff --git a/tests/unittests/text/inputs.py b/tests/unittests/text/inputs.py
index 7cc325c4bbd..eef565bb8fc 100644
--- a/tests/unittests/text/inputs.py
+++ b/tests/unittests/text/inputs.py
@@ -13,8 +13,16 @@
 # limitations under the License.
 from collections import namedtuple
 
+import torch
+
+from unittests.helpers import seed_all
+from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES
+
+seed_all(1)
+
 Input = namedtuple("Input", ["preds", "targets"])
 SquadInput = namedtuple("SquadInput", ["preds", "targets", "exact_match", "f1"])
+LogitsInput = namedtuple("LogitsInput", ["preds", "target"])
 
 # example taken from
 # https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu and adjusted
@@ -107,3 +115,20 @@
 # single reference
 TUPLE_OF_SINGLE_REFERENCES = ((REFERENCE_1A, REFERENCE_1B), (REFERENCE_1B, REFERENCE_1C))
 _inputs_single_reference = Input(preds=TUPLE_OF_HYPOTHESES, targets=TUPLE_OF_SINGLE_REFERENCES)
+
+# Logits-based inputs for perplexity metrics
+_logits_inputs_fp32 = LogitsInput(
+    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES, dtype=torch.float32),
+    target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
+)
+_logits_inputs_fp64 = LogitsInput(
+    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES, dtype=torch.float64),
+    target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
+)
+
+MASK_INDEX = -100
+_target_with_mask = _logits_inputs_fp32.target.clone()
+_target_with_mask[:, 0, 1:] = MASK_INDEX
+_target_with_mask[:, BATCH_SIZE - 1, :] = MASK_INDEX
+_logits_inputs_fp32_with_mask = LogitsInput(preds=_logits_inputs_fp32.preds, target=_target_with_mask)
+_logits_inputs_fp64_with_mask = LogitsInput(preds=_logits_inputs_fp64.preds, target=_target_with_mask)
diff --git a/tests/unittests/text/test_perplexity.py b/tests/unittests/text/test_perplexity.py
new file mode 100644
index 00000000000..40b6be0fabe
--- /dev/null
+++ b/tests/unittests/text/test_perplexity.py
@@ -0,0 +1,79 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from torchmetrics.functional.text.perplexity import perplexity
+from torchmetrics.text.perplexity import Perplexity
+from unittests.helpers.testers import MetricTester
+from unittests.text.inputs import (
+    MASK_INDEX,
+    _logits_inputs_fp32,
+    _logits_inputs_fp32_with_mask,
+    _logits_inputs_fp64,
+    _logits_inputs_fp64_with_mask,
+)
+
+
+def _reference_perplexity(preds, target, ignore_index):
+    """Reference perplexity based on PyTorch's cross-entropy; masked positions are dropped via ``ignore_index``."""
+    preds = preds.reshape(-1, preds.shape[-1])
+    target = target.reshape(-1)
+    cross_entropy = F.cross_entropy(preds, target, ignore_index=ignore_index if ignore_index is not None else -100)
+    return torch.exp(cross_entropy)
+
+
+@pytest.mark.parametrize(
+    "preds, target, ignore_index",
+    [
+        (_logits_inputs_fp32.preds, _logits_inputs_fp32.target, None),
+        (_logits_inputs_fp64.preds, _logits_inputs_fp64.target, None),
+        (_logits_inputs_fp32_with_mask.preds, _logits_inputs_fp32_with_mask.target, MASK_INDEX),
+        (_logits_inputs_fp64_with_mask.preds, _logits_inputs_fp64_with_mask.target, MASK_INDEX),
+    ],
+)
+class TestPerplexity(MetricTester):
+    @pytest.mark.parametrize("ddp", [False, True])
+    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
+    def test_perplexity_class(self, ddp, dist_sync_on_step, preds, target, ignore_index):
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=preds,
+            target=target,
+            metric_class=Perplexity,
+            sk_metric=partial(_reference_perplexity, ignore_index=ignore_index),
+            dist_sync_on_step=dist_sync_on_step,
+            metric_args={"ignore_index": ignore_index},
+        )
+
+    def test_perplexity_fn(self, preds, target, ignore_index):
+        self.run_functional_metric_test(
+            preds,
+            target,
+            metric_functional=perplexity,
+            sk_metric=partial(_reference_perplexity, ignore_index=ignore_index),
+            metric_args={"ignore_index": ignore_index},
+        )
+
+    def test_perplexity_differentiability(self, preds, target, ignore_index):
+        self.run_differentiability_test(
+            preds=preds,
+            target=target,
+            metric_module=Perplexity,
+            metric_functional=perplexity,
+            metric_args={"ignore_index": ignore_index},
+        )
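
Usage sketch (not part of the patch): a minimal example of the new metric, assuming a torchmetrics build that includes this change. It exercises both interfaces with the same inputs as the doctests and cross-checks the identity the unit tests rely on, namely that perplexity is the exponential of the mean cross-entropy over non-ignored tokens.

    import torch
    import torch.nn.functional as F

    from torchmetrics.functional.text.perplexity import perplexity
    from torchmetrics.text.perplexity import Perplexity

    # Mock language-model output: logits over a 5-token vocabulary
    # for 2 sequences of length 8 (same seed as the doctests).
    preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
    target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
    target[0, 6:] = -100  # mark trailing positions of the first sequence as padding

    # Functional interface: one-shot computation.
    ppl_fn = perplexity(preds, target, ignore_index=-100)

    # Module interface: accumulates summed NLL and token count across batches
    # until `compute` is called.
    metric = Perplexity(ignore_index=-100)
    metric.update(preds, target)
    ppl_mod = metric.compute()

    # Cross-check against F.cross_entropy, the reference used in the unit tests
    # (its default ignore_index is likewise -100). With this seed all three
    # values match the doctest output, tensor(5.2545).
    ppl_ref = torch.exp(F.cross_entropy(preds.reshape(-1, 5), target.reshape(-1)))
    assert torch.allclose(ppl_fn, ppl_ref) and torch.allclose(ppl_mod, ppl_ref)

Because the metric states are plain sums (`total_log_probs`, `count`) reduced with `dist_reduce_fx="sum"`, the same value falls out after DDP synchronization, which is what `test_perplexity_class` verifies with `ddp=True`.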