Skip to content

Commit

Permalink
Update bad_words_ids usage (huggingface#15641)
Browse files Browse the repository at this point in the history
* Improve the parameter `bad_word_ids' usage

* Update the bad_words_ids strategy
  • Loading branch information
ngoquanghuy99 authored and ManuelFay committed Mar 31, 2022
1 parent 5ec59af commit 4069746
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/transformers/generation_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,9 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
Args:
bad_words_ids (`List[List[int]]`):
List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
that should not appear in the generated text, use `tokenizer(bad_word, add_prefix_space=True).input_ids`.
List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
add_special_tokens=False).input_ids`.
eos_token_id (`int`):
The id of the *end-of-sequence* token.
"""
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,8 +901,8 @@ def generate(
If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
`decoder_input_ids`.
bad_words_ids(`List[List[int]]`, *optional*):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use `tokenizer(bad_word, add_prefix_space=True,
List of token ids that are not allowed to be generated. In order to get the token ids of the words that
should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
add_special_tokens=False).input_ids`.
num_return_sequences(`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch.
Expand Down

0 comments on commit 4069746

Please sign in to comment.