Fix single letter stop strings #31448

Merged · 5 commits · Jun 18, 2024
9 changes: 5 additions & 4 deletions src/transformers/generation/stopping_criteria.py
@@ -372,10 +372,11 @@ def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -
         token_valid_positions, token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
             token_list, token_indices, stop_strings
         )
-
-        max_valid_positions = max(
-            len(val) for positions in token_valid_positions.values() for val in positions.values()
-        )
+        all_valid_positions = [len(val) for positions in token_valid_positions.values() for val in positions.values()]
+        # In some cases, tokens may have no valid internal positions (such as single-character stop strings), so
+        # we need a fallback to handle this case
+        max_valid_positions = max(all_valid_positions) if all_valid_positions else 1
+        # There should always be at least one valid end_len, however, so no fallback needed here
         max_valid_end_lens = max(len(val) for positions in token_end_overlaps.values() for val in positions.values())
         vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1
         gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1)
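Why the fallback is needed: the matching-positions helper only records *internal* match positions, i.e. places where a stop string can begin strictly inside a token. A single-character stop string has no interior, so every position list is empty and the old `max(...)` was evaluated over an empty sequence, raising `ValueError`. Below is a minimal sketch of that failure mode and the fix; the shape of `token_valid_positions` here is illustrative, not taken from the actual helper output.

```python
# Illustrative only: mimics the shape of `token_valid_positions` when the stop
# string is a single character and no token has an internal match position.
token_valid_positions = {"a": {}}  # hypothetical: stop string "a" -> no positions

all_valid_positions = [
    len(val) for positions in token_valid_positions.values() for val in positions.values()
]  # == [] for single-character stop strings

# Old behavior: max() over an empty generator raises
# "ValueError: max() arg is an empty sequence".
try:
    max(len(val) for positions in token_valid_positions.values() for val in positions.values())
except ValueError as err:
    print(err)  # this is the crash the PR fixes

# New behavior: fall back to 1 so the embedding vector still has a valid width.
max_valid_positions = max(all_valid_positions) if all_valid_positions else 1
print(max_valid_positions)  # 1
```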
18 changes: 18 additions & 0 deletions tests/generation/test_stopping_criteria.py
@@ -208,6 +208,24 @@ def test_stop_string_embedding_vecs(self):
         token_lengths = embedding_vec[:, 2].tolist()
         self.assertEqual(token_lengths, [len(token) for token in token_list])
 
+    def test_single_letter_stop_string(self):
+        true_strings = ["a", "baa", "abc"]  # "abc" is a single token
+        false_strings = ["abbbbbbb", "b"]  # "abbbbbbb" is split into multiple tokens
+        stop_strings = ["a"]
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        tokenizer.padding_side = "left"
+
+        true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
+        false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
+
+        scores = None
+        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
+        for input_ids in true_input_ids["input_ids"]:
+            self.assertTrue(criteria(input_ids.unsqueeze(0), scores))
+        for input_ids in false_input_ids["input_ids"]:
+            self.assertFalse(criteria(input_ids.unsqueeze(0), scores))
+
     def test_criterias_per_row(self):
         text = "They completed the challenging puzzle, revealing the hidden image at the end"
         stop_strings = ["end"]
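For context, here is a sketch of how a user would hit this path through the public API. Before this fix, requesting a single-character stop string via `generate()` crashed while building the stop-string embedding vectors; with the fallback, generation terminates normally when the stop string appears. The prompt and `max_new_tokens` value are arbitrary; `stop_strings` requires the tokenizer to be passed alongside it so the criteria can be constructed.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

inputs = tokenizer("The alphabet starts with", return_tensors="pt")
# A single-character stop string previously raised ValueError here.
out = model.generate(**inputs, stop_strings=["a"], tokenizer=tokenizer, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```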