
RepetitionPenaltyLogitsProcessor and EncoderRepetitionPenaltyLogitsProcessor contain incorrect and unclear docstrings #25970

@larekrow

Description

  1. The class docstring of RepetitionPenaltyLogitsProcessor says that "tokens with higher scores are less likely to be selected". However, the paper states that "this penalized sampling works by discounting the scores of previously generated tokens", and the code lowers the score of penalized tokens (e.g. by multiplying a negative score by a penalty of 1.2, a value the paper highlights). The docstring should therefore be corrected to say that "tokens with higher scores are more likely to be selected" (see the sketch after the snippet below).

score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
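
For illustration, here is a minimal sketch of point 1, assuming RepetitionPenaltyLogitsProcessor can be imported from the top-level transformers namespace and applied directly to toy tensors (the vocabulary and scores below are made up):

import torch
from transformers import RepetitionPenaltyLogitsProcessor

# Toy vocabulary of 4 tokens; token 2 was already generated.
input_ids = torch.tensor([[2]])
scores = torch.tensor([[1.0, -1.0, 2.0, 0.5]])

processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)
penalized = processor(input_ids, scores)
# Token 2 had a positive score, so it is divided by 1.2 (2.0 -> ~1.67):
# the repeated token ends up with a *lower* score and is therefore
# *less* likely to be selected, i.e. a higher score means more likely.
print(penalized)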

  2. EncoderRepetitionPenaltyLogitsProcessor requires an additional encoder_input_ids argument, whose docstring says it holds "the encoder_input_ids that should not be repeated within the decoder ids". However, according to the class docstring, Add hallucination filter in generate() #18354 (comment), and the code, which increases the score of tokens found within the original input ids (e.g. by multiplying a negative score by 1 / 2 = 0.5, where hallucination_penalty = 2 is the value the PR author used), these are in fact the ids that should be repeated within the decoder ids (see the sketch after the snippet below).

self.penalty = 1 / penalty
self.encoder_input_ids = encoder_input_ids
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
score = torch.gather(scores, 1, self.encoder_input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
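
Again for illustration, a minimal sketch of point 2 under the same assumptions (top-level import, made-up tensors):

import torch
from transformers import EncoderRepetitionPenaltyLogitsProcessor

# Toy vocabulary of 4 tokens; the encoder (original) input contained token 1.
encoder_input_ids = torch.tensor([[1]])
decoder_input_ids = torch.tensor([[3]])
scores = torch.tensor([[0.5, -1.0, 0.0, 2.0]])

processor = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=encoder_input_ids)
boosted = processor(decoder_input_ids, scores)
# Token 1 had a negative score, and self.penalty is 1 / 2 = 0.5,
# so -1.0 * 0.5 = -0.5: the token from the encoder input becomes
# *more* likely to be repeated in the decoder output, not less.
print(boosted)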

  3. Both RepetitionPenaltyLogitsProcessor and EncoderRepetitionPenaltyLogitsProcessor take a penalty argument, which is only validated to be a strictly positive float. However, the argument only acts as a penalty when penalty > 1. If 0 < penalty < 1 is given, the "penalty" becomes a "reward". The docstrings do not mention this in any way (see the sketch after the snippets below).

RepetitionPenaltyLogitsProcessor

if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

EncoderRepetitionPenaltyLogitsProcessor

if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
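
A sketch of point 3, under the same assumptions as above: a penalty of 0.5 passes the validation shown here but rewards repetition instead of penalizing it.

import torch
from transformers import RepetitionPenaltyLogitsProcessor

# Same toy setup as before; token 2 was already generated.
input_ids = torch.tensor([[2]])
scores = torch.tensor([[1.0, -1.0, 2.0, 0.5]])

# penalty = 0.5 is a strictly positive float, so no error is raised,
# but the repeated token's positive score is divided by 0.5 (2.0 -> 4.0),
# turning the "penalty" into a "reward".
reward = RepetitionPenaltyLogitsProcessor(penalty=0.5)
print(reward(input_ids, scores))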

@gante

Before delving into the source code and other resources, I was genuinely confused by these contradictory messages. I hope the docstrings can be rectified for the benefit of other users.
