From 28316d0e8b4726c39bf75572c7788792c7da9ce9 Mon Sep 17 00:00:00 2001
From: Matt
Date: Tue, 18 Jun 2024 14:07:16 +0100
Subject: [PATCH] Fix single letter stop strings (#31448)

* Fix single letter stop strings

* Change the 0 to a 1 to avoid potential empty vector headaches later

* Restructure for clarity

* Update tests/generation/test_stopping_criteria.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Add the unsqueeze

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 .../generation/stopping_criteria.py        |  9 +++++----
 tests/generation/test_stopping_criteria.py | 18 ++++++++++++++++++
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py
index 9cc485e4601c6e..b1bf3dee9ae1d9 100644
--- a/src/transformers/generation/stopping_criteria.py
+++ b/src/transformers/generation/stopping_criteria.py
@@ -372,10 +372,11 @@ def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -
         token_valid_positions, token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
             token_list, token_indices, stop_strings
         )
-
-        max_valid_positions = max(
-            len(val) for positions in token_valid_positions.values() for val in positions.values()
-        )
+        all_valid_positions = [len(val) for positions in token_valid_positions.values() for val in positions.values()]
+        # In some cases, tokens may have no valid internal positions (such as single-character stop strings), so
+        # we need a fallback to handle this case
+        max_valid_positions = max(all_valid_positions) if all_valid_positions else 1
+        # There should always be at least one valid end_len, however, so no fallback needed here
         max_valid_end_lens = max(len(val) for positions in token_end_overlaps.values() for val in positions.values())
         vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1
         gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1)
diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py
index 1a22491b9aa0f6..ddf9a1c9379ea2 100644
--- a/tests/generation/test_stopping_criteria.py
+++ b/tests/generation/test_stopping_criteria.py
@@ -208,6 +208,24 @@ def test_stop_string_embedding_vecs(self):
         token_lengths = embedding_vec[:, 2].tolist()
         self.assertEqual(token_lengths, [len(token) for token in token_list])
 
+    def test_single_letter_stop_string(self):
+        true_strings = ["a", "baa", "abc"]  # "abc" is a single token
+        false_strings = ["abbbbbbb", "b"]  # "abbbbbbb" is split into multiple tokens
+        stop_strings = ["a"]
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        tokenizer.padding_side = "left"
+
+        true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
+        false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
+
+        scores = None
+        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
+        for input_ids in true_input_ids["input_ids"]:
+            self.assertTrue(criteria(input_ids.unsqueeze(0), scores))
+        for input_ids in false_input_ids["input_ids"]:
+            self.assertFalse(criteria(input_ids.unsqueeze(0), scores))
+
     def test_criterias_per_row(self):
         text = "They completed the challenging puzzle, revealing the hidden image at the end"
         stop_strings = ["end"]
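
A minimal standalone sketch of the failure mode the first hunk guards against (hypothetical data, not the transformers internals): Python's built-in max() raises ValueError on an empty sequence, and a single-character stop string like "a" has no valid internal token positions, so the comprehension can come up empty.

    # Hypothetical minimal reproduction; the dict shape here is an assumption
    # loosely modelled on {stop_string: {token: positions}} from the patch.
    token_valid_positions = {"a": {}}  # no token has valid internal positions

    all_valid_positions = [
        len(val) for positions in token_valid_positions.values() for val in positions.values()
    ]

    # Pre-fix behaviour: max() over the empty list raises
    # "ValueError: max() arg is an empty sequence".
    # max_valid_positions = max(all_valid_positions)

    # Post-fix behaviour: fall back to 1 (not 0), so the per-stop-string slice
    # of the gather vector can never be zero-width downstream.
    max_valid_positions = max(all_valid_positions) if all_valid_positions else 1
    print(max_valid_positions)  # -> 1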