Add the Perplexity metric (#922)
* Add the Perplexity metric

* Add changelog line and missing imports in __init__

* Fix the test and the examples

* Make the mask optional

* Fix mypy issues

* Fix mypy

* Fix modules docs

* Fix test

* Update docs/source/references/functional.rst

* Update docs/source/references/modules.rst

* Move `nanmean` to a utils file

* docs

* update

* Update Perplexity metrics and tests to follow our test suite

* Update docs to fix doctest

* Add ddp test cases which were unintentionally dropped

* Fix device placement

* space

* Apply suggestions from code review

* fix mistake

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: stancld <daniel.stancl@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
7 people authored Jun 30, 2022
1 parent b49e579 commit c44aca1
Showing 10 changed files with 354 additions and 2 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -11,10 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

+ - Added global option `sync_on_compute` to disable automatic syncronization when `compute` is called ([#1107](https://github.dev/Lightning-AI/metrics/pull/1107))
+ - Added `Perplexity` metric ([#922](https://github.com/PyTorchLightning/metrics/pull/922))


- -
- - Added global option `sync_on_compute` to disable automatic syncronization when `compute` is called ([#1107](https://github.dev/Lightning-AI/metrics/pull/1107))


### Changed
22 changes: 22 additions & 0 deletions docs/source/text/perplexity.rst
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Perplexity
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/summarization.svg
:tags: Text

.. include:: ../links.rst

##########
Perplexity
##########

Module Interface
________________

.. autoclass:: torchmetrics.text.perplexity.Perplexity
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.text.perplexity.perplexity
:noindex:
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
@@ -92,6 +92,7 @@
CHRFScore,
ExtendedEditDistance,
MatchErrorRate,
Perplexity,
SacreBLEUScore,
SQuAD,
TranslationEditRate,
@@ -157,6 +158,7 @@
"MultiScaleStructuralSimilarityIndexMeasure",
"PearsonCorrCoef",
"PermutationInvariantTraining",
"Perplexity",
"Precision",
"PrecisionRecallCurve",
"PeakSignalNoiseRatio",
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
@@ -78,6 +78,7 @@
from torchmetrics.functional.text.chrf import chrf_score
from torchmetrics.functional.text.eed import extended_edit_distance
from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.functional.text.perplexity import perplexity
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
from torchmetrics.functional.text.squad import squad
@@ -131,6 +132,7 @@
"pairwise_manhattan_distance",
"pearson_corrcoef",
"permutation_invariant_training",
"perplexity",
"pit_permutate",
"precision",
"precision_recall",
1 change: 1 addition & 0 deletions src/torchmetrics/functional/text/__init__.py
@@ -17,6 +17,7 @@
from torchmetrics.functional.text.chrf import chrf_score # noqa: F401
from torchmetrics.functional.text.eed import extended_edit_distance # noqa: F401
from torchmetrics.functional.text.mer import match_error_rate # noqa: F401
from torchmetrics.functional.text.perplexity import perplexity # noqa: F401
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score # noqa: F401
from torchmetrics.functional.text.squad import squad # noqa: F401
from torchmetrics.functional.text.ter import translation_edit_rate # noqa: F401
139 changes: 139 additions & 0 deletions src/torchmetrics/functional/text/perplexity.py
@@ -0,0 +1,139 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor

_TORCH_FLOAT_OR_DOUBLE = (torch.float32, torch.float64)


def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> None:
"""Check shape and type consistency of input vectors.
Args:
preds:
Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].
target:
Ground truth values with a shape [batch_size, seq_len].
Raises:
ValueError:
If the ``preds`` tensor does not have 3 dimensions.
ValueError:
If the ``target`` tensor does not have 2 dimensions.
ValueError:
If the first two dimensions of ``preds`` and ``target`` are not equal.
TypeError:
If ``preds`` dtype is not one of ``(torch.float32, torch.float64)``.
TypeError:
If ``target`` is not of type ``torch.int64`` (``LongTensor``).
"""
if len(preds.shape) != 3:
raise ValueError(
"Input tensor `preds` is expected to have 3 dimensions, [batch_size, seq_len, vocab_size],"
f" but got {len(preds.shape)}."
)
if len(target.shape) != 2:
raise ValueError(
"Input tensor `target` is expected to have 2 dimensions, [batch_size, seq_len],"
f" but got {len(target.shape)}."
)
if preds.shape[:2] != target.shape:
raise ValueError(
"Input tensors `preds` and `target` are expected to have equaling first two dimensions,"
f" [batch_size, seq_len], but got {preds.shape[:2]} and {target.shape}."
)
if preds.dtype not in _TORCH_FLOAT_OR_DOUBLE:
raise TypeError(
f"Input tensor `preds` is expected to be of a type one of {_TORCH_FLOAT_OR_DOUBLE} but got {preds.dtype}."
)
if target.dtype != torch.int64:
raise TypeError(f"Input tensor `target` is expected to be of type {torch.int64}, but got {target.dtype}.")


def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tuple[Tensor, Tensor]:
"""Compute intermediate statistics for Perplexity.
Args:
preds:
Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].
target:
Ground truth values with a shape [batch_size, seq_len].
ignore_index:
Integer specifying a target class to ignore. If given, this class index does not contribute
to the returned score.
Returns:
Negative log probabilities, summed over all non-ignored tokens
Number of non-ignored tokens
"""
_check_shape_and_type_consistency(preds, target)

probs = F.softmax(preds.reshape(-1, preds.shape[-1]), dim=1)
target = target.reshape(-1)

if ignore_index is not None:
mask = target.ne(ignore_index)
target = target.where(target != ignore_index, torch.tensor(0, device=target.device))
else:
mask = torch.ones_like(target, dtype=torch.bool)

probs = probs[:, target].diagonal()[mask]
total_log_probs = -probs.log().sum()
count = mask.sum()

return total_log_probs, count


def _perplexity_compute(total: Tensor, count: Tensor) -> Tensor:
"""Compute the Perplexity.
Args:
total: Negative log probabilities, summed over all non-ignored tokens
count: Number of non-ignored tokens
Returns:
Perplexity
"""
return torch.exp(total / count)


def perplexity(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tensor:
"""Perplexity measures how well a language model predicts a text sample. It's calculated as the average number
of bits per word a model needs to represent the sample.
Args:
preds:
Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].
target:
Ground truth values with a shape [batch_size, seq_len].
ignore_index:
Integer specifying a target class to ignore. If given, this class index does not contribute
to the returned score.
Returns:
Perplexity value
Examples:
>>> import torch
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> target[0, 6:] = -100
>>> perplexity(preds, target, ignore_index=-100)
tensor(5.2545)
"""
total, count = _perplexity_update(preds, target, ignore_index)
return _perplexity_compute(total, count)
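
Note that `_perplexity_update` applies a softmax to `preds`, so the doctest above passes raw `torch.rand` scores rather than pre-normalized probabilities. A minimal sanity-check sketch (not part of the commit): since the metric is the exponentiated mean negative log-likelihood over non-ignored tokens, it should agree with `torch.exp` of `F.cross_entropy` evaluated with the same `ignore_index`.

import torch
import torch.nn.functional as F

from torchmetrics.functional.text import perplexity

preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))        # [batch_size, seq_len, vocab_size]
target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))  # [batch_size, seq_len]
target[0, 6:] = -100                                                 # tokens to ignore

ppl = perplexity(preds, target, ignore_index=-100)
# cross_entropy averages the negative log-likelihood over non-ignored tokens,
# so exponentiating it should reproduce the metric value.
ref = torch.exp(F.cross_entropy(preds.reshape(-1, 5), target.reshape(-1), ignore_index=-100))
assert torch.allclose(ppl, ref)
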
1 change: 1 addition & 0 deletions src/torchmetrics/text/__init__.py
@@ -16,6 +16,7 @@
from torchmetrics.text.chrf import CHRFScore # noqa: F401
from torchmetrics.text.eed import ExtendedEditDistance # noqa: F401
from torchmetrics.text.mer import MatchErrorRate # noqa: F401
from torchmetrics.text.perplexity import Perplexity # noqa: F401
from torchmetrics.text.sacre_bleu import SacreBLEUScore # noqa: F401
from torchmetrics.text.squad import SQuAD # noqa: F401
from torchmetrics.text.ter import TranslationEditRate # noqa: F401
81 changes: 81 additions & 0 deletions src/torchmetrics/text/perplexity.py
@@ -0,0 +1,81 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional

from torch import Tensor, tensor

from torchmetrics.functional.text.perplexity import _perplexity_compute, _perplexity_update
from torchmetrics.metric import Metric


class Perplexity(Metric):
r"""
Perplexity measures how well a language model predicts a text sample. It is calculated as the exponentiated
average negative log-likelihood per token.
Args:
ignore_index:
Integer specifying a target class to ignore. If given, this class index does not contribute
to the returned score.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Examples:
>>> import torch
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> target[0, 6:] = -100
>>> metric = Perplexity(ignore_index=-100)
>>> metric(preds, target)
tensor(5.2545)
"""
is_differentiable = True
higher_is_better = False
full_state_update = False
total_log_probs: Tensor
count: Tensor

def __init__(
self,
ignore_index: Optional[int] = None,
**kwargs: Dict[str, Any],
):
super().__init__(**kwargs)
if ignore_index is not None and not isinstance(ignore_index, int):
raise ValueError(f"Argument `ignore_index` expected to either be `None` or an `int` but got {ignore_index}")
self.ignore_index = ignore_index
self.add_state("total_log_probs", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("count", default=tensor(0.0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Compute and store intermediate statistics for Perplexity.
Args:
preds:
Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].
target:
Ground truth values with a shape [batch_size, seq_len].
"""
total_log_probs, count = _perplexity_update(preds, target, self.ignore_index)
self.total_log_probs += total_log_probs
self.count += count

def compute(self) -> Tensor:
"""Compute the Perplexity.
Returns:
Perplexity
"""
return _perplexity_compute(self.total_log_probs, self.count)
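
A short usage sketch of the module interface (not part of the commit; shapes and batch count are illustrative): each `update` call adds to the summed negative log-likelihood and token count, and `compute` exponentiates their ratio over everything accumulated so far.

import torch

from torchmetrics.text import Perplexity

metric = Perplexity(ignore_index=-100)
for _ in range(4):                     # e.g. four validation batches
    preds = torch.rand(2, 8, 5)        # [batch_size, seq_len, vocab_size]
    target = torch.randint(5, (2, 8))  # [batch_size, seq_len]
    metric.update(preds, target)       # accumulates total_log_probs and count
print(metric.compute())                # perplexity over all accumulated tokens
metric.reset()
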
25 changes: 25 additions & 0 deletions tests/unittests/text/inputs.py
@@ -13,8 +13,16 @@
# limitations under the License.
from collections import namedtuple

import torch

from unittests.helpers import seed_all
from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES

seed_all(1)

Input = namedtuple("Input", ["preds", "targets"])
SquadInput = namedtuple("SquadInput", ["preds", "targets", "exact_match", "f1"])
LogitsInput = namedtuple("LogitsInput", ["preds", "target"])

# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu and adjusted
@@ -107,3 +115,20 @@
# single reference
TUPLE_OF_SINGLE_REFERENCES = ((REFERENCE_1A, REFERENCE_1B), (REFERENCE_1B, REFERENCE_1C))
_inputs_single_reference = Input(preds=TUPLE_OF_HYPOTHESES, targets=TUPLE_OF_SINGLE_REFERENCES)

# Logits-based inputs for perplexity metrics
_logits_inputs_fp32 = LogitsInput(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES, dtype=torch.float32),
target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
)
_logits_inputs_fp64 = LogitsInput(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES, dtype=torch.float64),
target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
)

MASK_INDEX = -100
_target_with_mask = _logits_inputs_fp32.target.clone()
_target_with_mask[:, 0, 1:] = MASK_INDEX
_target_with_mask[:, BATCH_SIZE - 1, :] = MASK_INDEX
_logits_inputs_fp32_with_mask = LogitsInput(preds=_logits_inputs_fp32.preds, target=_target_with_mask)
_logits_inputs_fp64_with_mask = LogitsInput(preds=_logits_inputs_fp64.preds, target=_target_with_mask)
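
The tenth changed file, the test module that exercises the new metric, is not rendered in this view. As a rough sketch of how the fixtures above could be consumed (a plain functional comparison against a small reference implementation, rather than the repository's shared test harness):

import torch

from torchmetrics.functional.text import perplexity

def _reference_perplexity(preds, target, ignore_index):
    # exp of the mean negative log-probability over non-ignored tokens
    probs = torch.softmax(preds.reshape(-1, preds.shape[-1]), dim=-1)
    target = target.reshape(-1)
    mask = target.ne(ignore_index)
    token_probs = probs[torch.arange(target.numel()), target.clamp(min=0)]
    return torch.exp(-token_probs[mask].log().sum() / mask.sum())

# iterate over the NUM_BATCHES dimension of the masked fixtures defined above
for preds, target in zip(_logits_inputs_fp32_with_mask.preds, _logits_inputs_fp32_with_mask.target):
    expected = _reference_perplexity(preds, target, ignore_index=MASK_INDEX)
    assert torch.allclose(perplexity(preds, target, ignore_index=MASK_INDEX), expected)
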