Exclude set of tokens from rep penalty calculation #1170

Closed
teknium1 opened this issue Oct 18, 2023 · 4 comments

@teknium1

Feature request

Hello, I would like to propose a feature that allows you to set a list of tokens, or even token sequences, that can be excluded from repetition penalty calculations.

The reasoning is that, given a multi-turn prompt format such as:

user: abc
assistant: def
user: ghi
assistant: jkl

Or, even worse, given a format like ChatML, where it is now standard to use <|im_end|> as the stopping token and to include it in every turn: since these tokens appear in every turn, it seems only logical that repetition penalty will break these prompt formats, especially when turns are only a few tokens long.

Motivation


Your contribution

I will attempt a solution soon, though I suspect I may get it wrong. Nonetheless, here is my theory on how to do it in TGI:

This seems to be the rep penalty code:

from typing import List

import torch
from transformers import LogitsProcessor


class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
    This version allows for a separate value for each sample and runs inplace when possible.
    It doesn't validate inputs.

    Args:
        repetition_penalty (`List[float]`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(
            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
        )

        scores.scatter_(1, input_ids, score)
        return scores

    def filter(self, indices):
        self.penalty = [self.penalty[i] for i in indices]
        if any([x != 1.0 for x in self.penalty]):
            self.penalty_tensor = self.penalty_tensor[indices]
            return self
        return None

My idea is to take the input_ids before they are scored and replace them with input_ids that omit any tokens from a configurable list of token IDs (or of token strings converted to IDs).
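A minimal sketch of the proposal, assuming the same tensor shapes as the processor above (`input_ids` is `[batch, seq_len]`, `scores` is `[batch, vocab]`). Filtering IDs out of `input_ids` directly would leave rows of different lengths, so this sketch swaps in an equivalent approach: apply the penalty as usual, then restore the original logits of the excluded IDs. The class name and the `excluded_token_ids` parameter are hypothetical and not part of TGI, and only single token IDs are handled, not multi-token sequences.

```python
from typing import List, Sequence

import torch


class ExcludingRepetitionPenaltyLogitsProcessor:
    """Hypothetical: repetition penalty that never touches `excluded_token_ids`."""

    def __init__(
        self,
        penalty: List[float],
        excluded_token_ids: Sequence[int],
        dtype: torch.dtype,
        device: torch.device,
    ):
        # One penalty per batch row, shaped [batch, 1] so it broadcasts over tokens
        self.penalty_tensor = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1)
        # IDs that must keep their unpenalized logits (e.g. the ID of <|im_end|>)
        self.excluded = torch.tensor(list(excluded_token_ids), dtype=torch.long, device=device)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Save the original logits of the excluded IDs before anything is modified
        original = scores[:, self.excluded].clone()

        # Standard penalty: same gather / where / scatter_ steps as the TGI processor
        score = torch.gather(scores, 1, input_ids)
        score = torch.where(score < 0, score * self.penalty_tensor, score / self.penalty_tensor)
        scores.scatter_(1, input_ids, score)

        # Undo the penalty for excluded IDs, whether or not they appeared in input_ids
        scores[:, self.excluded] = original
        return scores
```

For ChatML, `excluded_token_ids` would hold the ID of `<|im_end|>` (and possibly the role tokens), looked up with the model's tokenizer. Restoring logits after the fact keeps the batch tensors rectangular and leaves the existing penalty code path unchanged.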

@OlivierDehaene
Contributor

The repetition penalty is a one-off thing, i.e. it doesn't matter how many times a token is present; the penalty will be the same.

I suppose you CAN get unbounded generation, but I'm not sure it DOES happen. The probabilities associated with these two tokens during generation are so high that the penalty/sampling should not matter.

Do you have an experiment that could verify that this is indeed happening?
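The one-off behavior follows from the gather/scatter pair in the processor quoted earlier. The sketch below reuses those steps with a single float penalty (the helper name `apply_repetition_penalty` is hypothetical) and compares a token seen once with the same token seen four times.

```python
import torch


def apply_repetition_penalty(input_ids: torch.Tensor, scores: torch.Tensor, penalty: float) -> torch.Tensor:
    # Same gather -> where -> scatter steps as the processor quoted earlier,
    # with a single float penalty instead of a per-row tensor
    score = torch.gather(scores, 1, input_ids)
    score = torch.where(score < 0, score * penalty, score / penalty)
    # Duplicate IDs gather the same original logit and scatter the same penalized value,
    # so a token seen many times is still divided (or multiplied) only once
    return scores.scatter(1, input_ids, score)


scores = torch.tensor([[4.0, 4.0]])
seen_once = apply_repetition_penalty(torch.tensor([[0]]), scores, 2.0)
seen_four_times = apply_repetition_penalty(torch.tensor([[0, 0, 0, 0]]), scores, 2.0)
# In both cases token 0 drops from 4.0 to 2.0 and token 1 stays at 4.0
```

Because every duplicate writes an identical value, the order in which `scatter` applies them does not affect the result.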

@teknium1
Author


I haven't noticed it in my own inference with teknium/OpenHermes-2-Mistral-7B, despite it using the EOS token <|im_end|> as a turn delimiter, but according to this HF employee, LMSys has something you could theoretically test it on. However, I wasn't aware of the one-off behavior when I opened these feature requests; given that, if generation makes it through 2 turns, it should never become a problem.

Does Rep penalty fall off further away from when the token was seen last? I speculated that this may only appear with short turns (if it appears at all), so it might be a problem for roleplay or very terse answering models if recency is a big factor.

@OlivierDehaene
Contributor

> Does Rep penalty fall off further away from when the token was seen last?

No, it's constant. It could be interesting to add windowing to it, indeed.
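The windowing suggested above could be sketched as follows, assuming a fixed window size applied to the same `[batch, seq_len]` / `[batch, vocab]` tensors. The function name and `window` parameter are hypothetical; nothing here is confirmed to exist in TGI.

```python
import torch


def windowed_repetition_penalty(
    input_ids: torch.Tensor, scores: torch.Tensor, penalty: float, window: int
) -> torch.Tensor:
    """Hypothetical windowed variant: only the last `window` tokens count as repeated."""
    if window < 1:
        # input_ids[:, -0:] would select the whole sequence, so reject non-positive windows
        raise ValueError("window must be >= 1")
    recent = input_ids[:, -window:]  # [batch, min(window, seq_len)]
    score = torch.gather(scores, 1, recent)
    score = torch.where(score < 0, score * penalty, score / penalty)
    return scores.scatter(1, recent, score)


scores = torch.tensor([[4.0, 4.0, 4.0]])
input_ids = torch.tensor([[0, 1, 2]])
# With window=1 only token 2 (the most recent) is penalized: [[4.0, 4.0, 2.0]]
out = windowed_repetition_penalty(input_ids, scores, 2.0, window=1)
```

With a window shorter than a typical turn, a stop token such as <|im_end|> from earlier turns would fall out of the window and stop being penalized, which addresses the recency concern raised above without a separate exclusion list.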

@github-actions

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions bot added the Stale label on Dec 18, 2023
github-actions bot closed this as not planned (Won't fix, can't repro, duplicate, stale) on Dec 24, 2023