From fbff27623a69fe90fa06360811d19cb0312f0ce7 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 30 Aug 2024 16:26:26 +0100 Subject: [PATCH] Add warning for stop string edge case (#33169) * Add warning for edge case * make fixup --- src/transformers/generation/stopping_criteria.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index b8d6540ca2f793..7e98b11cf01a1c 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -348,7 +348,14 @@ def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) - # we need a fallback to handle this case max_valid_positions = max(all_valid_positions) if all_valid_positions else 1 # There should always be at least one valid end_len, however, so no fallback needed here - max_valid_end_lens = max(len(val) for positions in token_end_overlaps.values() for val in positions.values()) + valid_end_lens = [len(val) for positions in token_end_overlaps.values() for val in positions.values()] + if not valid_end_lens: + raise ValueError( + "Stop string preprocessing was unable to identify tokens matching one or more of the " + "supplied stop string(s). This is most often caused by the stop " + "strings containing unusual characters that are not in the tokenizer vocabulary." + ) + max_valid_end_lens = max(valid_end_lens) vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1 gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1)