Skip to content

Commit

Permalink
Rename additive_repetition_penalty to presence_penalty, add frequency…
Browse files Browse the repository at this point in the history
…_penalty (oobabooga#4376)
  • Loading branch information
tdrussell authored Oct 25, 2023
1 parent ef1489c commit 72f6fc6
Show file tree
Hide file tree
Showing 14 changed files with 64 additions and 30 deletions.
3 changes: 2 additions & 1 deletion api-examples/api-example-chat-stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ async def run(user_input, history):
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
'additive_repetition_penalty': 0,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'top_k': 40,
'min_length': 0,
Expand Down
3 changes: 2 additions & 1 deletion api-examples/api-example-chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def run(user_input, history):
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
'additive_repetition_penalty': 0,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'top_k': 40,
'min_length': 0,
Expand Down
3 changes: 2 additions & 1 deletion api-examples/api-example-stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ async def run(context):
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
'additive_repetition_penalty': 0,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'top_k': 40,
'min_length': 0,
Expand Down
3 changes: 2 additions & 1 deletion api-examples/api-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def run(prompt):
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
'additive_repetition_penalty': 0,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'top_k': 40,
'min_length': 0,
Expand Down
3 changes: 2 additions & 1 deletion docs/03 ‐ Parameters Tab.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ For more information about the parameters, the [transformers documentation](http
* **top_p**: If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.
* **top_k**: Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.
* **repetition_penalty**: Penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.
* **additive_repetition_penalty**: Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. It may generate better results. 0 means no penalty, higher value = less repetition, lower value = more repetition.
* **presence_penalty**: Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. It may generate better results. 0 means no penalty, higher value = less repetition, lower value = more repetition. Previously called "additive_repetition_penalty".
* **frequency_penalty**: Repetition penalty that scales based on how many times the token has appeared in the context. Be careful with this; there's no limit to how much a token can be penalized.
* **repetition_penalty_range**: The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.
* **typical_p**: If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.
* **tfs**: Tries to detect a tail of low-probability tokens in the distribution and removes those tokens. See [this blog post](https://www.trentonbricken.com/Tail-Free-Sampling/) for details. The closer to 0, the more discarded tokens.
Expand Down
3 changes: 2 additions & 1 deletion extensions/api/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def build_parameters(body, chat=False):
'tfs': float(body.get('tfs', 1)),
'top_a': float(body.get('top_a', 0)),
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
'additive_repetition_penalty': float(body.get('additive_repetition_penalty', body.get('additive_rep_pen', 0))),
'presence_penalty': float(body.get('presence_penalty', body.get('presence_pen', 0))),
'frequency_penalty': float(body.get('frequency_penalty', body.get('frequency_pen', 0))),
'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)),
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
'top_k': int(body.get('top_k', 0)),
Expand Down
3 changes: 2 additions & 1 deletion extensions/openai/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
'top_p': 1.0,
'top_k': 1, # choose 20 for chat in absence of another default
'repetition_penalty': 1.18,
'additive_repetition_penalty': 0,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'encoder_repetition_penalty': 1.0,
'suffix': None,
Expand Down
2 changes: 2 additions & 0 deletions modules/llamacpp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def generate(self, prompt, state, callback=None):
top_p=state['top_p'],
top_k=state['top_k'],
repeat_penalty=state['repetition_penalty'],
presence_penalty=state['presence_penalty'],
frequency_penalty=state['frequency_penalty'],
tfs_z=state['tfs'],
mirostat_mode=int(state['mirostat_mode']),
mirostat_tau=state['mirostat_tau'],
Expand Down
23 changes: 16 additions & 7 deletions modules/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@
'tfs',
'top_a',
'repetition_penalty',
'additive_repetition_penalty',
'presence_penalty',
'frequency_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
Expand Down Expand Up @@ -186,7 +187,8 @@
'tfs',
'top_a',
'repetition_penalty',
'additive_repetition_penalty',
'presence_penalty',
'frequency_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
Expand Down Expand Up @@ -245,7 +247,8 @@
'tfs',
'top_a',
'repetition_penalty',
'additive_repetition_penalty',
'presence_penalty',
'frequency_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
Expand Down Expand Up @@ -275,7 +278,8 @@
'tfs',
'top_a',
'repetition_penalty',
'additive_repetition_penalty',
'presence_penalty',
'frequency_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
Expand Down Expand Up @@ -309,7 +313,8 @@
'tfs',
'top_a',
'repetition_penalty',
'additive_repetition_penalty',
'presence_penalty',
'frequency_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
Expand Down Expand Up @@ -339,6 +344,8 @@
'top_k',
'tfs',
'repetition_penalty',
'presence_penalty',
'frequency_penalty',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
Expand All @@ -357,7 +364,8 @@
'tfs',
'top_a',
'repetition_penalty',
'additive_repetition_penalty',
'presence_penalty',
'frequency_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
Expand Down Expand Up @@ -394,7 +402,8 @@
'tfs',
'top_a',
'repetition_penalty',
'additive_repetition_penalty',
'presence_penalty',
'frequency_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
Expand Down
3 changes: 2 additions & 1 deletion modules/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def default_preset():
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1,
'additive_repetition_penalty': 0,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
Expand Down
37 changes: 25 additions & 12 deletions modules/sampler_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,24 +139,35 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
Copied from the transformers library
'''

def __init__(self, penalty: float, additive_penalty: float, _range: int):
def __init__(self, penalty: float, presence_penalty: float, frequency_penalty: float, _range: int):
if not (penalty > 0):
raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")

self.penalty = penalty
self.additive_penalty = additive_penalty
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self._range = _range

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

input_ids = input_ids[:, -self._range:]
score = torch.gather(scores, 1, input_ids)

# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
score -= self.additive_penalty
# We loop here because torch.unique() needs to process each row separately in the
# case that batch_size > 1.
for input_ids_row, scores_row in zip(input_ids, scores):
unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
score = torch.gather(scores_row, 0, unique_ids)

# multiplicative repetition penalty
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores_row.scatter_(0, unique_ids, score)

# presence_penalty and frequency_penalty
raw_presence_penalty = (counts > 0).to(scores.dtype)
raw_frequency_penalty = counts.to(scores.dtype)
additive_penalty = raw_presence_penalty*self.presence_penalty + raw_frequency_penalty*self.frequency_penalty
scores_row.scatter_add_(0, unique_ids, -additive_penalty)

scores.scatter_(1, input_ids, score)
return scores


Expand Down Expand Up @@ -188,9 +199,10 @@ def get_logits_warper_patch(self, generation_config):

def get_logits_processor_patch(self, **kwargs):
repetition_penalty = kwargs['generation_config'].repetition_penalty
additive_repetition_penalty = kwargs['generation_config'].additive_repetition_penalty
presence_penalty = kwargs['generation_config'].presence_penalty
frequency_penalty = kwargs['generation_config'].frequency_penalty
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
do_rep_pen_hijack = (repetition_penalty > 1) or (additive_repetition_penalty > 0)
do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
if do_rep_pen_hijack:
# Make sure that a RepetitionPenaltyLogitsProcessor will be created
kwargs['generation_config'].repetition_penalty = 1.1 # must set to some value > 1
Expand All @@ -200,7 +212,7 @@ def get_logits_processor_patch(self, **kwargs):
if do_rep_pen_hijack:
for i in range(len(result)):
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, additive_repetition_penalty, repetition_penalty_range)
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, presence_penalty, frequency_penalty, repetition_penalty_range)

return result

Expand All @@ -213,7 +225,8 @@ def generation_config_init_patch(self, **kwargs):
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
self.additive_repetition_penalty = kwargs.pop("additive_repetition_penalty", 0)
self.presence_penalty = kwargs.pop("presence_penalty", 0)
self.frequency_penalty = kwargs.pop("frequency_penalty", 0)


def hijack_samplers():
Expand Down
2 changes: 1 addition & 1 deletion modules/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def apply_stopping_strings(reply, all_stop_strings):

def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
generate_params[k] = state[k]

if state['negative_prompt'] != '':
Expand Down
3 changes: 2 additions & 1 deletion modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def list_interface_input_elements():
'epsilon_cutoff',
'eta_cutoff',
'repetition_penalty',
'additive_repetition_penalty',
'presence_penalty',
'frequency_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
Expand Down
3 changes: 2 additions & 1 deletion modules/ui_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def create_ui(default_preset):
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
shared.gradio['additive_repetition_penalty'] = gr.Slider(0, 4, value=generate_params['additive_repetition_penalty'], step=0.05, label='additive_repetition_penalty')
shared.gradio['presence_penalty'] = gr.Slider(0, 4, value=generate_params['presence_penalty'], step=0.05, label='presence_penalty')
shared.gradio['frequency_penalty'] = gr.Slider(0, 2, value=generate_params['frequency_penalty'], step=0.05, label='frequency_penalty')
shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range')
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs')
Expand Down

0 comments on commit 72f6fc6

Please sign in to comment.