-
Notifications
You must be signed in to change notification settings - Fork 402
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add the Perplexity metric * Add changelog line and missing imports in __init__ * Fix the test and the examples * Make the mask optional * Fix mypy issues * Fix mypy * Fix modules docs * Fix test * Update docs/source/references/functional.rst * Update docs/source/references/modules.rst * Move `nanmean` to a utils file * docs * update * Update Perplexity metrics and tests to follow our test suite * Update docs to fix doctest * Add ddp test cases which were unintentionally dropped * Fix device placement * space * Apply suggestions from code review * fix mistake Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: stancld <daniel.stancl@gmail.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
- Loading branch information
1 parent
b49e579
commit c44aca1
Showing
10 changed files
with
354 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
.. customcarditem::
   :header: Perplexity
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/summarization.svg
   :tags: Text

.. include:: ../links.rst

##########
Perplexity
##########

Module Interface
________________

.. autoclass:: torchmetrics.text.perplexity.Perplexity
    :noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.text.perplexity.perplexity
    :noindex:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Optional, Tuple | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torch import Tensor | ||
|
||
_TORCH_FLOAT_OR_DOUBLE = (torch.float32, torch.float64) | ||
|
||
|
||
def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> None: | ||
"""Check shape and type consistency of input vectors. | ||
Args: | ||
preds: | ||
Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size]. | ||
target: | ||
Ground truth values with a shape [batch_size, seq_len]. | ||
Raises: | ||
ValueError: | ||
If ``preds`` tensor has no 3 dimensions. | ||
ValueError: | ||
If ``target`` tensor has no 2 dimensions. | ||
ValueError: | ||
If the first two dimensions of ``preds`` and ``target`` do not equal. | ||
TypeError: | ||
If ``preds`` dtype is not one of ``(torch.float16, torch.float32, torch.float64)`` | ||
TypeError: | ||
If ``target`` is not of a type LongTensor (torch.int64) | ||
""" | ||
if len(preds.shape) != 3: | ||
raise ValueError( | ||
"Input tensor `preds` is expected to have 3 dimensions, [batch_size, seq_len, vocab_size]," | ||
f" but got {len(preds.shape)}." | ||
) | ||
if len(target.shape) != 2: | ||
raise ValueError( | ||
"Input tensor `target` is expected to have 2 dimensions, [batch_size, seq_len]," | ||
f" but got {len(target.shape)}." | ||
) | ||
if preds.shape[:2] != target.shape: | ||
raise ValueError( | ||
"Input tensors `preds` and `target` are expected to have equaling first two dimensions," | ||
f" [batch_size, seq_len], but got {preds.shape[:2]} and {target.shape}." | ||
) | ||
if preds.dtype not in _TORCH_FLOAT_OR_DOUBLE: | ||
raise TypeError( | ||
f"Input tensor `preds` is expected to be of a type one of {_TORCH_FLOAT_OR_DOUBLE} but got {preds.dtype}." | ||
) | ||
if target.dtype != torch.int64: | ||
raise TypeError(f"Input tensor `target` is expected to be of a type {torch.int64} but got {target.dtype}.") | ||
|
||
|
||
def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tuple[Tensor, Tensor]:
    """Compute intermediate statistics for Perplexity.

    Args:
        preds:
            Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].
        target:
            Ground truth values with a shape [batch_size, seq_len].
        ignore_index:
            Integer specifying a target class to ignore. If given, this class index does not contribute
            to the returned score.

    Returns:
        Log probabilities, summed over all samples
        Number of samples
    """
    _check_shape_and_type_consistency(preds, target)

    probs = F.softmax(preds.reshape(-1, preds.shape[-1]), dim=1)
    target = target.reshape(-1)

    if ignore_index is not None:
        mask = target.ne(ignore_index)
        # Remap ignored positions to a valid class index (0) so the gather below
        # stays in-bounds; they are dropped again by the mask.
        target = target.where(target != ignore_index, torch.tensor(0, device=target.device))
    else:
        mask = torch.ones_like(target, dtype=torch.bool)

    # Pick each token's probability directly with `gather` — O(N) memory.
    # The previous `probs[:, target].diagonal()` materialized a full [N, N]
    # matrix (N = batch_size * seq_len) only to read its diagonal.
    probs = probs.gather(1, target.unsqueeze(1)).squeeze(1)[mask]
    total_log_probs = -probs.log().sum()
    count = mask.sum()

    return total_log_probs, count
|
||
|
||
def _perplexity_compute(total: Tensor, count: Tensor) -> Tensor: | ||
"""Compute the Perplexity. | ||
Args: | ||
total: Log probabilities, summed over all samples | ||
count: Number of samples | ||
Returns: | ||
Perplexity | ||
""" | ||
return torch.exp(total / count) | ||
|
||
|
||
def perplexity(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tensor:
    """Perplexity measures how well a language model predicts a text sample.

    The score is the exponential of the mean negative log-probability the model
    assigns to the ground-truth tokens; lower is better.

    Args:
        preds:
            Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].
        target:
            Ground truth values with a shape [batch_size, seq_len].
        ignore_index:
            Integer specifying a target class to ignore. If given, this class index does not contribute
            to the returned score.

    Returns:
        Perplexity value

    Examples:
        >>> import torch
        >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
        >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
        >>> target[0, 6:] = -100
        >>> perplexity(preds, target, ignore_index=-100)
        tensor(5.2545)
    """
    stats = _perplexity_update(preds, target, ignore_index)
    return _perplexity_compute(*stats)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any, Dict, Optional | ||
|
||
from torch import Tensor, tensor | ||
|
||
from torchmetrics.functional.text.perplexity import _perplexity_compute, _perplexity_update | ||
from torchmetrics.metric import Metric | ||
|
||
|
||
class Perplexity(Metric):
    r"""
    Perplexity measures how well a language model predicts a text sample. It's calculated as the average number of bits
    per word a model needs to represent the sample.

    Args:
        ignore_index:
            Integer specifying a target class to ignore. If given, this class index does not contribute
            to the returned score.
        kwargs:
            Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Examples:
        >>> import torch
        >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
        >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
        >>> target[0, 6:] = -100
        >>> metric = Perplexity(ignore_index=-100)
        >>> metric(preds, target)
        tensor(5.2545)
    """
    is_differentiable = True
    higher_is_better = False
    full_state_update = False
    # Running sum of negative log-probabilities and the number of counted tokens.
    total_log_probs: Tensor
    count: Tensor

    def __init__(
        self,
        ignore_index: Optional[int] = None,
        # `Any`, not `Dict[str, Any]`: in a `**kwargs` annotation the type
        # describes each keyword *value*, and the base Metric accepts any value.
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if ignore_index is not None and not isinstance(ignore_index, int):
            raise ValueError(f"Argument `ignore_index` expected to either be `None` or an `int` but got {ignore_index}")
        self.ignore_index = ignore_index
        # Both states are plain sums, so "sum" reduction gives correct DDP results.
        self.add_state("total_log_probs", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Compute and store intermediate statistics for Perplexity.

        Args:
            preds:
                Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].
            target:
                Ground truth values with a shape [batch_size, seq_len].
        """
        total_log_probs, count = _perplexity_update(preds, target, self.ignore_index)
        self.total_log_probs += total_log_probs
        self.count += count

    def compute(self) -> Tensor:
        """Compute the Perplexity over all data seen so far.

        Returns:
            Perplexity
        """
        return _perplexity_compute(self.total_log_probs, self.count)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.