diff --git a/src/transformers/data/test_generation_utils.py b/src/transformers/data/test_generation_utils.py
new file mode 100644
index 00000000000000..f253945bad4264
--- /dev/null
+++ b/src/transformers/data/test_generation_utils.py
@@ -0,0 +1,90 @@
+import random
+import unittest
+
+import timeout_decorator
+
+from transformers import is_torch_available
+from transformers.file_utils import cached_property
+from transformers.testing_utils import require_torch
+
+
+if is_torch_available():
+    import torch
+
+    from transformers import (
+        MarianConfig,
+        MarianMTModel,
+    )
+
+
+@require_torch
+class GenerationUtilsTest(unittest.TestCase):
+    @cached_property
+    def config(self):
+        config = MarianConfig.from_pretrained("sshleifer/tiny-marian-en-de")
+        return config
+
+    @cached_property
+    def model(self):
+        return MarianMTModel(self.config)
+
+    def test_postprocess_next_token_scores(self):
+        config = self.config
+        model = self.model
+        # Initialize an input id tensor with batch size 8 and sequence length 12
+        input_ids = torch.arange(0, 96, 1).view((8, 12))
+        eos = config.eos_token_id
+        bad_words_ids_test_cases = [[[299]], [[23, 24], [54]], [[config.eos_token_id]], []]
+        masked_scores = [
+            [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)],
+            [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)],
+            [(0, eos), (1, eos), (2, eos), (3, eos), (4, eos), (5, eos), (6, eos), (7, eos)],
+            [],
+        ]
+
+        for test_case_index, bad_words_ids in enumerate(bad_words_ids_test_cases):
+            # Initialize a scores tensor with batch size 8 and vocabulary size 300
+            scores = torch.rand((8, 300))
+            output = model.postprocess_next_token_scores(
+                scores,
+                input_ids,
+                0,
+                bad_words_ids,
+                13,
+                15,
+                config.max_length,
+                config.eos_token_id,
+                config.repetition_penalty,
+                32,
+                5,
+            )
+            for masked_score in masked_scores[test_case_index]:
+                self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf"))
+
+    @timeout_decorator.timeout(10)
+    def test_postprocess_next_token_scores_large_bad_words_list(self):
+
+        config = self.config
+        model = self.model
+        # Initialize an input id tensor with batch size 8 and sequence length 12
+        input_ids = torch.arange(0, 96, 1).view((8, 12))
+
+        bad_words_ids = []
+        for _ in range(100):
+            length_bad_word = random.randint(1, 4)
+            bad_words_ids.append(random.sample(range(1, 300), length_bad_word))
+
+        scores = torch.rand((8, 300))
+        _ = model.postprocess_next_token_scores(
+            scores,
+            input_ids,
+            0,
+            bad_words_ids,
+            13,
+            15,
+            config.max_length,
+            config.eos_token_id,
+            config.repetition_penalty,
+            32,
+            5,
+        )
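For context (not part of the patch), the nested `bad_words_ids` lists exercised by these tests have the same shape a caller would pass to `generate()` through its `bad_words_ids` argument. A minimal sketch, assuming the tiny Marian checkpoint used above also ships tokenizer files and can be downloaded; tokenizer call names may differ slightly between library versions:

# Sketch only: how a caller typically builds bad_words_ids for generate().
# Assumes network access to the tiny test checkpoint named in the test above.
from transformers import MarianMTModel, MarianTokenizer

name = "sshleifer/tiny-marian-en-de"
tokenizer = MarianTokenizer.from_pretrained(name)
model = MarianMTModel.from_pretrained(name)

# One inner list per banned word or phrase; multi-token entries ban the whole sequence.
bad_words_ids = [tokenizer.encode(word, add_special_tokens=False) for word in ["und", "der"]]

inputs = tokenizer("I like apples", return_tensors="pt")
out = model.generate(**inputs, bad_words_ids=bad_words_ids, num_beams=2, max_length=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))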
diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 3c49e15abf7bc8..c731c01fa9837a 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -89,11 +89,12 @@ def postprocess_next_token_scores(
                 scores[i, banned_tokens] = -float("inf")
 
         if bad_words_ids is not None:
+            # Exclude EOS token (already processed)
+            bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
             # calculate a list of banned tokens according to bad words
-            banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
-
-            for i, banned_tokens in enumerate(banned_tokens):
-                scores[i, banned_tokens] = -float("inf")
+            banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
+            # Modify the scores in place by setting the banned tokens logits to `-inf`
+            set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
 
         return scores
 
@@ -893,7 +894,7 @@ def _tokens_match(prev_tokens, tokens):
                 bad_words_ids
            )
 
-            if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
+            if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
                 # if tokens do not match continue
                 continue
 
@@ -904,6 +905,30 @@ def _tokens_match(prev_tokens, tokens):
     return banned_tokens
 
 
+def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
+    """ Modifies the scores in place by setting the banned token positions to `-inf`.
+    `banned_tokens` is expected to be a list of lists, holding the token ids to ban for each batch entry.
+        Args:
+            scores: logits distribution of shape (batch size, vocabulary size)
+            banned_tokens: list of lists of token ids to ban, of length (batch_size)
+    """
+    banned_mask_list = []
+    for idx, batch_banned_tokens in enumerate(banned_tokens):
+        for token in batch_banned_tokens:
+            banned_mask_list.append([idx, token])
+    if not banned_mask_list:
+        return
+    banned_mask = torch.LongTensor(banned_mask_list)
+    indices = torch.ones(len(banned_mask))
+    # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to a dense tensor generates:
+    # [ 0  1  1 ]
+    # [ 0  0  0 ]
+    # [ 1  0  0 ]
+
+    banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to_dense().bool()
+    scores.masked_fill_(banned_mask, -float("inf"))
+
+
 def top_k_top_p_filtering(
     logits: Tensor,
     top_k: int = 0,
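As a side note (not part of the patch), the sparse-to-dense masking trick inside `set_scores_to_inf_for_banned_tokens` can be sketched in isolation. This sketch uses `torch.sparse_coo_tensor`, the non-legacy equivalent of the `torch.sparse.LongTensor` constructor used above; shapes and values are illustrative assumptions:

import torch

scores = torch.rand((3, 5))                    # (batch_size, vocab_size)
banned_coordinates = [[0, 1], [0, 2], [2, 0]]  # (batch index, token id) pairs to ban

# Scatter the coordinates into a dense boolean mask with the same shape as `scores`:
# row 0 masks tokens 1 and 2, row 2 masks token 0, row 1 masks nothing.
coords = torch.tensor(banned_coordinates).t()  # shape (2, num_banned): one row per dimension
values = torch.ones(coords.size(1))
banned_mask = torch.sparse_coo_tensor(coords, values, scores.size()).to_dense().bool()

scores.masked_fill_(banned_mask, -float("inf"))  # in place, as in the patch
print(scores)

This mirrors the design choice in the patch: a single fused masked_fill_ per generation step instead of one fancy-indexing assignment per batch row.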