From b036e261b2877b0e8b4fa67c233635a690456375 Mon Sep 17 00:00:00 2001
From: larekrow <127832774+larekrow@users.noreply.github.com>
Date: Fri, 13 Oct 2023 16:28:52 +0800
Subject: [PATCH 1/2] Update logits_process.py docstrings + match arg fields to __init__'s

---
 src/transformers/generation/logits_process.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 14f772ab6c99ee..edd9805f845296 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -276,9 +276,14 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
     paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
 
+    This technique can also be used to reward and thus encourage repetition in a similar manner. To penalize and reduce
+    repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
+    repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.
+
     Args:
-        repetition_penalty (`float`):
-            The parameter for repetition penalty. 1.0 means no penalty. See [this
+        penalty (`float`):
+            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated tokens.
+            Between 0.0 and 1.0 rewards previously generated tokens. See [this
             paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
 
     Examples:
@@ -313,7 +318,7 @@ def __init__(self, penalty: float):
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, input_ids)
 
-        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         scores.scatter_(1, input_ids, score)
@@ -322,11 +327,18 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
 
 class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
     r"""
-    [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input.
+    [`LogitsProcessor`] that avoids hallucination by boosting the probabilities of tokens found within the original
+    input.
+
+    This technique can also be used to reward and thus encourage hallucination (or creativity) in a similar manner. To
+    penalize and reduce hallucination, use `penalty` values above 1.0, where a higher value penalizes more strongly. To
+    reward and encourage hallucination, use `penalty` values between 0.0 and 1.0, where a lower value rewards more
+    strongly.
 
     Args:
-        hallucination_penalty (`float`):
-            The parameter for hallucination penalty. 1.0 means no penalty.
+        penalty (`float`):
+            The parameter for hallucination penalty. 1.0 means no penalty. Above 1.0 penalizes hallucination. Between
+            0.0 and 1.0 rewards hallucination.
         encoder_input_ids (`torch.LongTensor`):
             The encoder_input_ids that should be repeated within the decoder ids.
     """
@@ -342,7 +354,7 @@ def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, self.encoder_input_ids)
 
-        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        # if score < 0 then hallucination penalty has to be multiplied to increase the token probabilities
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         scores.scatter_(1, self.encoder_input_ids, score)

From 423f6b2f72e32c4bc68aa6da8e96cea6b997799a Mon Sep 17 00:00:00 2001
From: larekrow
Date: Tue, 17 Oct 2023 14:18:10 +0800
Subject: [PATCH 2/2] Ran `make style`

---
 src/transformers/generation/logits_process.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index edd9805f845296..3ab689b55a767d 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -282,8 +282,8 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
 
     Args:
         penalty (`float`):
-            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated tokens.
-            Between 0.0 and 1.0 rewards previously generated tokens. See [this
+            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
+            tokens. Between 0.0 and 1.0 rewards previously generated tokens. See [this
             paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
 
     Examples: