Update logits_process.py docstrings #25971

Merged
merged 1 commit on Sep 12, 2023
4 changes: 2 additions & 2 deletions src/transformers/generation/logits_process.py
@@ -272,7 +272,7 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     [`LogitsProcessor`] that prevents the repetition of previous tokens through an exponential penalty. This technique
     shares some similarities with coverage mechanisms and other aimed at reducing repetition. During the text
     generation process, the probability distribution for the next token is determined using a formula that incorporates
-    token scores based on their occurrence in the generated sequence. Tokens with higher scores are less likely to be
+    token scores based on their occurrence in the generated sequence. Tokens with higher scores are more likely to be
     selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
     paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
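
For context on the corrected sentence: a higher (post-penalty) score always means a token is more likely to be sampled, and the penalty works by lowering the scores of tokens that have already been generated. Below is a minimal sketch of that CTRL-style update, assuming a `(batch, vocab)` score tensor; the function name and tensor names are illustrative, not the library's exact implementation.

```python
import torch

def apply_repetition_penalty(
    scores: torch.Tensor, generated_ids: torch.Tensor, penalty: float = 1.2
) -> torch.Tensor:
    # Pull out the scores of every token that already occurs in the generated sequence.
    score = torch.gather(scores, 1, generated_ids)
    # Negative scores are multiplied by the penalty and positive scores divided by it,
    # so previously generated tokens always end up less probable (for penalty > 1).
    score = torch.where(score < 0, score * penalty, score / penalty)
    scores.scatter_(1, generated_ids, score)
    return scores

# Example: with penalty=1.2, a seen token with score 1.5 drops to 1.25 and one with
# score -0.5 drops to -0.6, so both become less likely to be selected.
scores = apply_repetition_penalty(torch.tensor([[1.5, -0.5, 0.2]]), torch.tensor([[0, 1]]))
```

In `model.generate(...)` this behaviour corresponds to passing `repetition_penalty=1.2`.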

@@ -328,7 +328,7 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
         hallucination_penalty (`float`):
             The parameter for hallucination penalty. 1.0 means no penalty.
         encoder_input_ids (`torch.LongTensor`):
-            The encoder_input_ids that should not be repeated within the decoder ids.
+            The encoder_input_ids that should be repeated within the decoder ids.
     """

     def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
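
For context on the corrected `encoder_input_ids` description: this processor rewards tokens that appear in the encoder input rather than forbidding them, which is why "should be repeated" is the accurate wording. A hedged sketch of the idea follows (illustrative only, not the exact library code):

```python
import torch

def apply_hallucination_penalty(
    scores: torch.Tensor, encoder_input_ids: torch.Tensor, hallucination_penalty: float = 2.0
) -> torch.Tensor:
    # The processor effectively applies the *inverse* of the usual repetition penalty to
    # tokens from the encoder input, boosting their scores so the decoder is nudged to
    # repeat them instead of hallucinating unrelated tokens.
    inverse_penalty = 1.0 / hallucination_penalty
    score = torch.gather(scores, 1, encoder_input_ids)
    score = torch.where(score < 0, score * inverse_penalty, score / inverse_penalty)
    scores.scatter_(1, encoder_input_ids, score)
    return scores
```

With `hallucination_penalty > 1.0` the scores of source tokens are raised, biasing the decoder toward copying from the input; `1.0` leaves the scores unchanged, matching the "no penalty" note in the docstring.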