Skip to content

Commit

Permalink
requested changes to location of checks and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
mapmeld committed Sep 28, 2022
1 parent 540cec0 commit 189b228
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
13 changes: 13 additions & 0 deletions src/transformers/generation_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,19 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to


class TypicalLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information.
Args:
mass (`float`):
        Value of typical_p, strictly between 0 and 1 (exclusive), defaults to 0.9.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""

def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
mass = float(mass)
if not (mass > 0 and mass < 1):
Expand Down
16 changes: 3 additions & 13 deletions src/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,6 @@ class GenerationMixin:
`do_sample=False`.
- *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *typical decoding* by calling [`~generation_utils.GenerationMixin.sample`] if `typical_p` between 0 and 1,
and `do_sample=True`.
- *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
`do_sample=False`.
- *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
Expand Down Expand Up @@ -1318,17 +1316,6 @@ def generate(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
)

if typical_p is not None:
if typical_p < 0 or typical_p > 1:
raise ValueError("Value of typical_p must be between 0 and 1.")
elif do_sample is False:
raise ValueError(
"Decoder argument `typical_p` must be used in sampling mode. "
"Make sure that `do_sample` is set to `True`."
)
elif is_group_beam_gen_mode is True:
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")

# 7. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
Expand Down Expand Up @@ -1499,6 +1486,9 @@ def generate(
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")

if typical_p is not None:
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")

# 10. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
Expand Down

0 comments on commit 189b228

Please sign in to comment.