From 6ff6548e4476c43ba2958ac525c99dc0f5bec6dd Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Sun, 26 Jul 2020 10:46:07 +0200 Subject: [PATCH 01/16] Optimized banned token masking --- src/transformers/generation_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 3c49e15abf7bc8..07f89c6d438a66 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -90,7 +90,10 @@ def postprocess_next_token_scores( if bad_words_ids is not None: # calculate a list of banned tokens according to bad words - banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids) + banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) + banned_mask = torch.zeros_like(scores) + banned_mask[torch.arange(banned_mask.size(0), dtype=torch.long), banned_tokens] = 1 + scores.masked_fill_(banned_mask.bool(), -float("inf")) for i, banned_tokens in enumerate(banned_tokens): scores[i, banned_tokens] = -float("inf") @@ -893,7 +896,7 @@ def _tokens_match(prev_tokens, tokens): bad_words_ids ) - if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False: + if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False: # if tokens do not match continue continue From 3f046f669ee86fd7479739b70c92a3ff3c40ded4 Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Sun, 26 Jul 2020 11:05:05 +0200 Subject: [PATCH 02/16] Avoid duplicate EOS masking if in bad_words_id --- src/transformers/generation_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 07f89c6d438a66..36f47605f1a61a 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -89,15 +89,14 @@ def postprocess_next_token_scores( scores[i, banned_tokens] = -float("inf") if bad_words_ids is not None: + # Exclude EOS token (already processed) + bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) # calculate a list of banned tokens according to bad words banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) banned_mask = torch.zeros_like(scores) banned_mask[torch.arange(banned_mask.size(0), dtype=torch.long), banned_tokens] = 1 scores.masked_fill_(banned_mask.bool(), -float("inf")) - for i, banned_tokens in enumerate(banned_tokens): - scores[i, banned_tokens] = -float("inf") - return scores @torch.no_grad() From c717232ae6e38de0c23dce0f1bff01eecb24d7c9 Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Mon, 27 Jul 2020 19:03:08 +0200 Subject: [PATCH 03/16] Updated mask generation to handle empty banned token list --- src/transformers/generation_utils.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 36f47605f1a61a..5e324fe6126f26 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -93,9 +93,20 @@ def postprocess_next_token_scores( bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) # calculate a list of banned tokens according to bad words banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) - banned_mask = torch.zeros_like(scores) - banned_mask[torch.arange(banned_mask.size(0), dtype=torch.long), banned_tokens] = 1 - scores.masked_fill_(banned_mask.bool(), -float("inf")) + banned_mask_list = [] + for idx, batch_banned_tokens in enumerate(banned_tokens): + for token in batch_banned_tokens: + banned_mask_list.append([idx, token]) + if len(banned_mask_list) > 0: + banned_mask = torch.LongTensor(banned_mask_list) + indices = torch.ones(len(banned_mask)) + banned_mask = ( + torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()) + .to_dense() + .bool() + .to(scores.device) + ) + scores.masked_fill_(banned_mask, -float("inf")) return scores From f0adddb41a80ccaf647fc204a72c646004c632a6 Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Tue, 28 Jul 2020 18:42:04 +0200 Subject: [PATCH 04/16] Addition of unit tests for the updated bad_words_ids masking --- src/transformers/generation_utils.py | 5 +++ tests/test_modeling_common.py | 66 +++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 5e324fe6126f26..52be935ad8dd56 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -100,6 +100,11 @@ def postprocess_next_token_scores( if len(banned_mask_list) > 0: banned_mask = torch.LongTensor(banned_mask_list) indices = torch.ones(len(banned_mask)) + # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates: + # [ 0 1 1 ] + # [ 0 0 0 ] + # [ 1 0 0 ] + banned_mask = ( torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()) .to_dense() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 097c387543cc5d..d1626615b3f07e 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -20,7 +20,9 @@ import unittest from typing import List -from transformers import is_torch_available +import timeout_decorator + +from transformers import MarianConfig, MarianMTModel, is_torch_available from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device @@ -51,7 +53,6 @@ def _config_zero_init(config): @require_torch class ModelTesterMixin: - model_tester = None all_model_classes = () all_generative_model_classes = () @@ -971,3 +972,64 @@ def test_top_k_top_p_filtering(self): self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12)) self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx))) + + @require_torch + def test_postprocess_next_token_scores(self): + config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") + model = MarianMTModel(config=config) + # Initialize an input id tensor with batch size 8 and sequence length 12 + input_ids = torch.arange(0, 96, 1).view((8, 12)) + + bad_words_ids_test_cases = [[[299]], [[23, 24], [54]], [[config.eos_token_id]], []] + masked_scores = [ + [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)], + [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)], + [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0), (7, 0)], + [], + ] + + for test_case_index, bad_words_ids in enumerate(bad_words_ids_test_cases): + # Initialize a scores tensor with batch size 8 and vocabulary size 300 + scores = torch.rand((8, 300)) + output = model.postprocess_next_token_scores( + scores, + input_ids, + 0, + bad_words_ids, + 13, + 15, + config.max_length, + config.eos_token_id, + config.repetition_penalty, + 32, + 5, + ) + for masked_score in masked_scores[test_case_index]: + self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf")) + + @timeout_decorator.timeout(1) + def test_postprocess_next_token_scores_large_bad_words_list(self): + config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") + model = MarianMTModel(config=config) + # Initialize an input id tensor with batch size 8 and sequence length 12 + input_ids = torch.arange(0, 96, 1).view((8, 12)) + + bad_words_ids = [] + for _ in range(1000): + length_bad_word = random.randint(1, 4) + bad_words_ids.append(random.sample(range(1, 300), length_bad_word)) + + scores = torch.rand((8, 300)) + _ = model.postprocess_next_token_scores( + scores, + input_ids, + 0, + bad_words_ids, + 13, + 15, + config.max_length, + config.eos_token_id, + config.repetition_penalty, + 32, + 5, + ) From f4dabda1144b7c038a07c0bbd7c39ef71d6f487d Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Tue, 28 Jul 2020 18:56:02 +0200 Subject: [PATCH 05/16] Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test --- tests/test_modeling_common.py | 52 +++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d1626615b3f07e..223288cbadc352 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1007,29 +1007,33 @@ def test_postprocess_next_token_scores(self): for masked_score in masked_scores[test_case_index]: self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf")) - @timeout_decorator.timeout(1) + @require_torch + @timeout_decorator.timeout(2) def test_postprocess_next_token_scores_large_bad_words_list(self): - config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") - model = MarianMTModel(config=config) - # Initialize an input id tensor with batch size 8 and sequence length 12 - input_ids = torch.arange(0, 96, 1).view((8, 12)) + try: + config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") + model = MarianMTModel(config=config) + # Initialize an input id tensor with batch size 8 and sequence length 12 + input_ids = torch.arange(0, 96, 1).view((8, 12)) - bad_words_ids = [] - for _ in range(1000): - length_bad_word = random.randint(1, 4) - bad_words_ids.append(random.sample(range(1, 300), length_bad_word)) - - scores = torch.rand((8, 300)) - _ = model.postprocess_next_token_scores( - scores, - input_ids, - 0, - bad_words_ids, - 13, - 15, - config.max_length, - config.eos_token_id, - config.repetition_penalty, - 32, - 5, - ) + bad_words_ids = [] + for _ in range(1000): + length_bad_word = random.randint(1, 4) + bad_words_ids.append(random.sample(range(1, 300), length_bad_word)) + + scores = torch.rand((8, 300)) + _ = model.postprocess_next_token_scores( + scores, + input_ids, + 0, + bad_words_ids, + 13, + 15, + config.max_length, + config.eos_token_id, + config.repetition_penalty, + 32, + 5, + ) + except OSError: + print("Test timed-out") From e1e0b55f2a5d8861f3eb9e57605fe2f3f2dc8e9a Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Tue, 28 Jul 2020 19:02:36 +0200 Subject: [PATCH 06/16] Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test (timeout does not work on Windows) --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 223288cbadc352..42d9e5f0cd2403 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1035,5 +1035,5 @@ def test_postprocess_next_token_scores_large_bad_words_list(self): 32, 5, ) - except OSError: + except timeout_decorator.timeout_decorator.TimeoutError: print("Test timed-out") From 7d3686bce4f685c6adfe4b679422dbec9aefb41b Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Tue, 28 Jul 2020 19:11:59 +0200 Subject: [PATCH 07/16] Moving Marian import to the test context to allow TF only environments to run --- tests/test_modeling_common.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 42d9e5f0cd2403..a20a0fb1f9489e 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -22,7 +22,7 @@ import timeout_decorator -from transformers import MarianConfig, MarianMTModel, is_torch_available +from transformers import is_torch_available from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device @@ -975,6 +975,8 @@ def test_top_k_top_p_filtering(self): @require_torch def test_postprocess_next_token_scores(self): + from transformers import MarianConfig, MarianMTModel + config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") model = MarianMTModel(config=config) # Initialize an input id tensor with batch size 8 and sequence length 12 @@ -1010,6 +1012,8 @@ def test_postprocess_next_token_scores(self): @require_torch @timeout_decorator.timeout(2) def test_postprocess_next_token_scores_large_bad_words_list(self): + from transformers import MarianConfig, MarianMTModel + try: config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") model = MarianMTModel(config=config) From bd8e9085db6eaefefbfcd720896b9898e221f511 Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Tue, 28 Jul 2020 19:27:54 +0200 Subject: [PATCH 08/16] Moving imports to torch_available test --- tests/test_modeling_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a20a0fb1f9489e..e69d1e1e20edd5 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -40,6 +40,8 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING, MODEL_FOR_QUESTION_ANSWERING_MAPPING, top_k_top_p_filtering, + MarianConfig, + MarianMTModel, ) @@ -975,7 +977,6 @@ def test_top_k_top_p_filtering(self): @require_torch def test_postprocess_next_token_scores(self): - from transformers import MarianConfig, MarianMTModel config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") model = MarianMTModel(config=config) @@ -1012,7 +1013,6 @@ def test_postprocess_next_token_scores(self): @require_torch @timeout_decorator.timeout(2) def test_postprocess_next_token_scores_large_bad_words_list(self): - from transformers import MarianConfig, MarianMTModel try: config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") From d8a9368bc727e69cb668cfc4357b019499ffc5b1 Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Wed, 29 Jul 2020 19:01:43 +0200 Subject: [PATCH 09/16] Updated operations device and test --- src/transformers/generation_utils.py | 11 ++---- tests/test_modeling_common.py | 51 +++++++++++++--------------- 2 files changed, 27 insertions(+), 35 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 52be935ad8dd56..263fcad73cf186 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -98,19 +98,14 @@ def postprocess_next_token_scores( for token in batch_banned_tokens: banned_mask_list.append([idx, token]) if len(banned_mask_list) > 0: - banned_mask = torch.LongTensor(banned_mask_list) - indices = torch.ones(len(banned_mask)) + banned_mask = torch.LongTensor(banned_mask_list).to(scores.device) + indices = torch.ones(len(banned_mask)).to(scores.device) # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates: # [ 0 1 1 ] # [ 0 0 0 ] # [ 1 0 0 ] - banned_mask = ( - torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()) - .to_dense() - .bool() - .to(scores.device) - ) + banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to_dense().bool() scores.masked_fill_(banned_mask, -float("inf")) return scores diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e69d1e1e20edd5..aa4b31b41abde9 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1011,33 +1011,30 @@ def test_postprocess_next_token_scores(self): self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf")) @require_torch - @timeout_decorator.timeout(2) + @timeout_decorator.timeout(10) def test_postprocess_next_token_scores_large_bad_words_list(self): - try: - config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") - model = MarianMTModel(config=config) - # Initialize an input id tensor with batch size 8 and sequence length 12 - input_ids = torch.arange(0, 96, 1).view((8, 12)) - - bad_words_ids = [] - for _ in range(1000): - length_bad_word = random.randint(1, 4) - bad_words_ids.append(random.sample(range(1, 300), length_bad_word)) + config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") + model = MarianMTModel(config=config) + # Initialize an input id tensor with batch size 8 and sequence length 12 + input_ids = torch.arange(0, 96, 1).view((8, 12)) - scores = torch.rand((8, 300)) - _ = model.postprocess_next_token_scores( - scores, - input_ids, - 0, - bad_words_ids, - 13, - 15, - config.max_length, - config.eos_token_id, - config.repetition_penalty, - 32, - 5, - ) - except timeout_decorator.timeout_decorator.TimeoutError: - print("Test timed-out") + bad_words_ids = [] + for _ in range(100): + length_bad_word = random.randint(1, 4) + bad_words_ids.append(random.sample(range(1, 300), length_bad_word)) + + scores = torch.rand((8, 300)) + _ = model.postprocess_next_token_scores( + scores, + input_ids, + 0, + bad_words_ids, + 13, + 15, + config.max_length, + config.eos_token_id, + config.repetition_penalty, + 32, + 5, + ) From 3e0e8fbccb7d10e0139dfdedb0d6773bce3e7f2c Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Wed, 29 Jul 2020 19:35:54 +0200 Subject: [PATCH 10/16] Updated operations device and test --- src/transformers/generation_utils.py | 36 ++++++++++++++++------------ 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 263fcad73cf186..f64a7947f8cef1 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import Tensor @@ -93,20 +93,7 @@ def postprocess_next_token_scores( bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) # calculate a list of banned tokens according to bad words banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) - banned_mask_list = [] - for idx, batch_banned_tokens in enumerate(banned_tokens): - for token in batch_banned_tokens: - banned_mask_list.append([idx, token]) - if len(banned_mask_list) > 0: - banned_mask = torch.LongTensor(banned_mask_list).to(scores.device) - indices = torch.ones(len(banned_mask)).to(scores.device) - # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates: - # [ 0 1 1 ] - # [ 0 0 0 ] - # [ 1 0 0 ] - - banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to_dense().bool() - scores.masked_fill_(banned_mask, -float("inf")) + set_scores_to_inf_for_banned_tokens(scores, banned_tokens) return scores @@ -917,6 +904,25 @@ def _tokens_match(prev_tokens, tokens): return banned_tokens +def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None: + banned_mask_list = [] + for idx, batch_banned_tokens in enumerate(banned_tokens): + for token in batch_banned_tokens: + banned_mask_list.append([idx, token]) + if len(banned_mask_list) > 0: + banned_mask = torch.LongTensor(banned_mask_list) + indices = torch.ones(len(banned_mask)) + # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates: + # [ 0 1 1 ] + # [ 0 0 0 ] + # [ 1 0 0 ] + + banned_mask = ( + torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool() + ) + scores.masked_fill_(banned_mask, -float("inf")) + + def top_k_top_p_filtering( logits: Tensor, top_k: int = 0, From 4f74e2a7760cfd3b6c34eaf6e02c1d20334aa34b Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Wed, 29 Jul 2020 20:55:27 +0200 Subject: [PATCH 11/16] Added docstring and comment for in-place scores modification --- src/transformers/generation_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index f64a7947f8cef1..94f2eb3d681dbc 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -93,6 +93,7 @@ def postprocess_next_token_scores( bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) # calculate a list of banned tokens according to bad words banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) + # Modify the scores in place by setting the banned tokens logits to `-inf` set_scores_to_inf_for_banned_tokens(scores, banned_tokens) return scores @@ -905,6 +906,12 @@ def _tokens_match(prev_tokens, tokens): def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None: + """ Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be + a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...] + Args: + scores: logits distribution of shape (batch size, vocabulary size) + banned_tokens: list of list of tokens to ban of length (batch_size) + """ banned_mask_list = [] for idx, batch_banned_tokens in enumerate(banned_tokens): for token in batch_banned_tokens: From 4a2a4e6babe81cc060cb96df8a592d987ad8e11e Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Wed, 29 Jul 2020 21:10:36 +0200 Subject: [PATCH 12/16] Moving test to own test_generation_utils, use of lighter models for testing --- .../data/test_generation_utils.py | 95 +++++++++++++++++++ src/transformers/generation_utils.py | 23 +++-- tests/test_modeling_common.py | 64 ------------- 3 files changed, 106 insertions(+), 76 deletions(-) create mode 100644 src/transformers/data/test_generation_utils.py diff --git a/src/transformers/data/test_generation_utils.py b/src/transformers/data/test_generation_utils.py new file mode 100644 index 00000000000000..981a8dbd2e4b0d --- /dev/null +++ b/src/transformers/data/test_generation_utils.py @@ -0,0 +1,95 @@ +import random +import unittest + +import timeout_decorator + +from transformers import is_torch_available +from transformers.file_utils import cached_property +from transformers.testing_utils import require_torch + + +if is_torch_available(): + import torch + + from transformers import ( + MarianConfig, + MarianMTModel, + ) + + +@require_torch +class GenerationUtilsTest(unittest.TestCase): + @cached_property + def config_and_model(self): + config = MarianConfig.from_pretrained("sshleifer/tiny-marian-en-de") + return config, MarianMTModel(config) + + @require_torch + def test_postprocess_next_token_scores(self): + config, model = self.config_and_model + # Initialize an input id tensor with batch size 8 and sequence length 12 + input_ids = torch.arange(0, 96, 1).view((8, 12)) + + bad_words_ids_test_cases = [[[299]], [[23, 24], [54]], [[config.eos_token_id]], []] + masked_scores = [ + [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)], + [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)], + [ + (0, config.eos_token_id), + (1, config.eos_token_id), + (2, config.eos_token_id), + (3, config.eos_token_id), + (4, config.eos_token_id), + (5, config.eos_token_id), + (6, config.eos_token_id), + (7, config.eos_token_id), + ], + [], + ] + + for test_case_index, bad_words_ids in enumerate(bad_words_ids_test_cases): + # Initialize a scores tensor with batch size 8 and vocabulary size 300 + scores = torch.rand((8, 300)) + output = model.postprocess_next_token_scores( + scores, + input_ids, + 0, + bad_words_ids, + 13, + 15, + config.max_length, + config.eos_token_id, + config.repetition_penalty, + 32, + 5, + ) + for masked_score in masked_scores[test_case_index]: + self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf")) + + @require_torch + @timeout_decorator.timeout(10) + def test_postprocess_next_token_scores_large_bad_words_list(self): + + config, model = self.config_and_model + # Initialize an input id tensor with batch size 8 and sequence length 12 + input_ids = torch.arange(0, 96, 1).view((8, 12)) + + bad_words_ids = [] + for _ in range(100): + length_bad_word = random.randint(1, 4) + bad_words_ids.append(random.sample(range(1, 300), length_bad_word)) + + scores = torch.rand((8, 300)) + _ = model.postprocess_next_token_scores( + scores, + input_ids, + 0, + bad_words_ids, + 13, + 15, + config.max_length, + config.eos_token_id, + config.repetition_penalty, + 32, + 5, + ) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 94f2eb3d681dbc..c731c01fa9837a 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -916,18 +916,17 @@ def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: Lis for idx, batch_banned_tokens in enumerate(banned_tokens): for token in batch_banned_tokens: banned_mask_list.append([idx, token]) - if len(banned_mask_list) > 0: - banned_mask = torch.LongTensor(banned_mask_list) - indices = torch.ones(len(banned_mask)) - # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates: - # [ 0 1 1 ] - # [ 0 0 0 ] - # [ 1 0 0 ] - - banned_mask = ( - torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool() - ) - scores.masked_fill_(banned_mask, -float("inf")) + if not banned_mask_list: + return + banned_mask = torch.LongTensor(banned_mask_list) + indices = torch.ones(len(banned_mask)) + # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates: + # [ 0 1 1 ] + # [ 0 0 0 ] + # [ 1 0 0 ] + + banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool() + scores.masked_fill_(banned_mask, -float("inf")) def top_k_top_p_filtering( diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index aa4b31b41abde9..4efae09abfc927 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -974,67 +974,3 @@ def test_top_k_top_p_filtering(self): self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12)) self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx))) - - @require_torch - def test_postprocess_next_token_scores(self): - - config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") - model = MarianMTModel(config=config) - # Initialize an input id tensor with batch size 8 and sequence length 12 - input_ids = torch.arange(0, 96, 1).view((8, 12)) - - bad_words_ids_test_cases = [[[299]], [[23, 24], [54]], [[config.eos_token_id]], []] - masked_scores = [ - [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)], - [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)], - [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0), (7, 0)], - [], - ] - - for test_case_index, bad_words_ids in enumerate(bad_words_ids_test_cases): - # Initialize a scores tensor with batch size 8 and vocabulary size 300 - scores = torch.rand((8, 300)) - output = model.postprocess_next_token_scores( - scores, - input_ids, - 0, - bad_words_ids, - 13, - 15, - config.max_length, - config.eos_token_id, - config.repetition_penalty, - 32, - 5, - ) - for masked_score in masked_scores[test_case_index]: - self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf")) - - @require_torch - @timeout_decorator.timeout(10) - def test_postprocess_next_token_scores_large_bad_words_list(self): - - config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en") - model = MarianMTModel(config=config) - # Initialize an input id tensor with batch size 8 and sequence length 12 - input_ids = torch.arange(0, 96, 1).view((8, 12)) - - bad_words_ids = [] - for _ in range(100): - length_bad_word = random.randint(1, 4) - bad_words_ids.append(random.sample(range(1, 300), length_bad_word)) - - scores = torch.rand((8, 300)) - _ = model.postprocess_next_token_scores( - scores, - input_ids, - 0, - bad_words_ids, - 13, - 15, - config.max_length, - config.eos_token_id, - config.repetition_penalty, - 32, - 5, - ) From 568fbb91c6f5dd9c504ce55224b8cd9f84f2908d Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Wed, 29 Jul 2020 21:11:37 +0200 Subject: [PATCH 13/16] removed unneded imports in test_modeling_common --- tests/test_modeling_common.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4efae09abfc927..6849fc173abaca 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -20,8 +20,6 @@ import unittest from typing import List -import timeout_decorator - from transformers import is_torch_available from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device @@ -40,8 +38,6 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING, MODEL_FOR_QUESTION_ANSWERING_MAPPING, top_k_top_p_filtering, - MarianConfig, - MarianMTModel, ) From 2618221182a7dd55aa9f3d19fc6d5840b27c5a91 Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Wed, 29 Jul 2020 21:12:31 +0200 Subject: [PATCH 14/16] revert formatting change for ModelTesterMixin --- tests/test_modeling_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6849fc173abaca..097c387543cc5d 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -51,6 +51,7 @@ def _config_zero_init(config): @require_torch class ModelTesterMixin: + model_tester = None all_model_classes = () all_generative_model_classes = () From 01c0697c05b65485bc47069cde10b1169df2bd0e Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Wed, 29 Jul 2020 22:36:51 +0200 Subject: [PATCH 15/16] Updated caching, simplified eos token id test, removed unnecessary @require_torch --- .../data/test_generation_utils.py | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/src/transformers/data/test_generation_utils.py b/src/transformers/data/test_generation_utils.py index 981a8dbd2e4b0d..043d257ade90f5 100644 --- a/src/transformers/data/test_generation_utils.py +++ b/src/transformers/data/test_generation_utils.py @@ -20,30 +20,25 @@ @require_torch class GenerationUtilsTest(unittest.TestCase): @cached_property - def config_and_model(self): + def config(self): config = MarianConfig.from_pretrained("sshleifer/tiny-marian-en-de") - return config, MarianMTModel(config) + return config + + @cached_property + def model(self): + return MarianMTModel(self.config) - @require_torch def test_postprocess_next_token_scores(self): - config, model = self.config_and_model + config = self.config + model = self.model # Initialize an input id tensor with batch size 8 and sequence length 12 input_ids = torch.arange(0, 96, 1).view((8, 12)) - + eos = config.eos_token_id bad_words_ids_test_cases = [[[299]], [[23, 24], [54]], [[config.eos_token_id]], []] masked_scores = [ [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)], [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)], - [ - (0, config.eos_token_id), - (1, config.eos_token_id), - (2, config.eos_token_id), - (3, config.eos_token_id), - (4, config.eos_token_id), - (5, config.eos_token_id), - (6, config.eos_token_id), - (7, config.eos_token_id), - ], + [(0, eos), (1, eos), (2, eos), (3, eos), (4, eos), (5, eos), (6, eos), (7, eos),], [], ] @@ -66,11 +61,11 @@ def test_postprocess_next_token_scores(self): for masked_score in masked_scores[test_case_index]: self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf")) - @require_torch @timeout_decorator.timeout(10) def test_postprocess_next_token_scores_large_bad_words_list(self): - config, model = self.config_and_model + config = self.config + model = self.model # Initialize an input id tensor with batch size 8 and sequence length 12 input_ids = torch.arange(0, 96, 1).view((8, 12)) From 7d9767a9a69cca004560b05489d329c8fc79bb71 Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Wed, 29 Jul 2020 22:40:51 +0200 Subject: [PATCH 16/16] formatting compliance --- src/transformers/data/test_generation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/data/test_generation_utils.py b/src/transformers/data/test_generation_utils.py index 043d257ade90f5..f253945bad4264 100644 --- a/src/transformers/data/test_generation_utils.py +++ b/src/transformers/data/test_generation_utils.py @@ -38,7 +38,7 @@ def test_postprocess_next_token_scores(self): masked_scores = [ [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)], [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)], - [(0, eos), (1, eos), (2, eos), (3, eos), (4, eos), (5, eos), (6, eos), (7, eos),], + [(0, eos), (1, eos), (2, eos), (3, eos), (4, eos), (5, eos), (6, eos), (7, eos)], [], ]