
Commit 7130a22

gante and ArthurZucker authored
Generate: consistently handle special tokens as tensors (#30624)
* tmp commit
* [test_all] mvp
* missing not
* [test_all] final test fixes
* fix musicgen_melody and rag
* [test_all] empty commit
* PR comments
* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent 5413b89 commit 7130a22

File tree: 12 files changed, +297 −191 lines changed

src/transformers/generation/beam_search.py (+24 −16)
@@ -218,8 +218,8 @@ def process(
         next_scores: torch.FloatTensor,
         next_tokens: torch.LongTensor,
         next_indices: torch.LongTensor,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
+        pad_token_id: Optional[Union[int, torch.Tensor]] = None,
+        eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
         beam_indices: Optional[torch.LongTensor] = None,
         group_index: Optional[int] = 0,
         decoder_prompt_len: Optional[int] = 0,
@@ -245,8 +245,10 @@ def process(
         next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
         next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
+        if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)

         for batch_idx in range(batch_size):
             batch_group_idx = batch_idx * self.num_beam_groups + group_index
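The same normalization guard recurs throughout this diff: an `eos_token_id` given as an int or list of ints is converted to a 1-D tensor once, so downstream code can assume tensor semantics. A minimal standalone sketch of the pattern (the helper name `as_eos_tensor` is illustrative, not part of the library):

import torch
from typing import List, Optional, Union

def as_eos_tensor(eos_token_id: Optional[Union[int, List[int], torch.Tensor]]) -> Optional[torch.Tensor]:
    # normalize int -> [int] -> tensor; leave None and tensors untouched
    if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id = torch.tensor(eos_token_id)
    return eos_token_id

assert torch.equal(as_eos_tensor(2), torch.tensor([2]))
assert torch.equal(as_eos_tensor([2, 50256]), torch.tensor([2, 50256]))
assert as_eos_tensor(None) is None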
@@ -322,15 +324,17 @@ def finalize(
         final_beam_tokens: torch.LongTensor,
         final_beam_indices: torch.LongTensor,
         max_length: int,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
+        pad_token_id: Optional[Union[int, torch.Tensor]] = None,
+        eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
         beam_indices: Optional[torch.LongTensor] = None,
         decoder_prompt_len: Optional[int] = 0,
     ) -> Tuple[torch.LongTensor]:
         batch_size = len(self._beam_hyps) // self.num_beam_groups

-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
+        if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)

         # finalize all open beam hypotheses and add to generated hypotheses
         for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
@@ -513,8 +517,8 @@ def process(
         next_tokens: torch.LongTensor,
         next_indices: torch.LongTensor,
         scores_for_all_vocab: torch.FloatTensor,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
+        pad_token_id: Optional[Union[int, torch.Tensor]] = None,
+        eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
         beam_indices: Optional[torch.LongTensor] = None,
         decoder_prompt_len: Optional[int] = 0,
     ) -> Tuple[torch.Tensor]:
@@ -578,8 +582,10 @@ def process(
         next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
         next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
+        if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)

         for batch_idx, beam_hyp in enumerate(self._beam_hyps):
             if self._done[batch_idx]:
@@ -811,15 +817,17 @@ def finalize(
         final_beam_tokens: torch.LongTensor,
         final_beam_indices: torch.LongTensor,
         max_length: int,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
+        pad_token_id: Optional[Union[int, torch.Tensor]] = None,
+        eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
         beam_indices: Optional[torch.LongTensor] = None,
         decoder_prompt_len: Optional[int] = 0,
     ) -> Tuple[torch.LongTensor]:
         batch_size = len(self._beam_hyps)

-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
+        if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)

         # finalize all open beam hypotheses and add to generated hypotheses
         for batch_idx, beam_hyp in enumerate(self._beam_hyps):

src/transformers/generation/logits_process.py (+81 −56)
@@ -108,8 +108,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
     Args:
         min_length (`int`):
             The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`):
+            The id(s) of the *end-of-sequence* token.

     Examples:

@@ -137,23 +137,23 @@ class MinLengthLogitsProcessor(LogitsProcessor):
     ```
     """

-    def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
+    def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
         if not isinstance(min_length, int) or min_length < 0:
             raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")

-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
-            logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+        if not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)

         self.min_length = min_length
         self.eos_token_id = eos_token_id

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
-        eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
+        self.eos_token_id = self.eos_token_id.to(scores.device)
+        eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
         scores_processed = scores.clone()
         if input_ids.shape[-1] < self.min_length:
             scores_processed = torch.where(eos_token_mask, -math.inf, scores)
@@ -171,8 +171,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
             input length.
         min_new_tokens (`int`):
             The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`):
+            The id(s) of the *end-of-sequence* token.

     Examples:

@@ -195,18 +195,20 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
     ```
     """

-    def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
+    def __init__(
+        self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor]
+    ):
         for arg_name, arg_value in [
             ("prompt_length_to_skip", prompt_length_to_skip),
             ("min_new_tokens", min_new_tokens),
         ]:
             if not isinstance(arg_value, int) or arg_value < 0:
                 raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")

-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
-            logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+        if not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)

         self.prompt_length_to_skip = prompt_length_to_skip
         self.min_new_tokens = min_new_tokens
@@ -217,8 +219,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
         scores_processed = scores.clone()
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
-        eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
+        self.eos_token_id = self.eos_token_id.to(scores.device)
+        eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
         if new_tokens_length < self.min_new_tokens:
             scores_processed = torch.where(eos_token_mask, -math.inf, scores)

@@ -1195,8 +1197,8 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
     Args:
         bad_words_ids (`List[List[int]]`):
             List of list of token ids that are not allowed to be generated.
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`, *optional*):
+            The id(s) of the *end-of-sequence* token.

     Examples:

@@ -1233,18 +1235,22 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
     ```
     """

-    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
+    def __init__(
+        self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None
+    ):
         self.bad_word_ids = bad_words_ids
         self._validate_arguments()

         # Filter EOS token from bad_words_ids
-        if eos_token_id is None:
-            eos_token_id = []
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        bad_words_ids = list(
-            filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
-        )
+        if eos_token_id is not None:
+            if not isinstance(eos_token_id, torch.Tensor):
+                if isinstance(eos_token_id, int):
+                    eos_token_id = [eos_token_id]
+                eos_token_id = torch.tensor(eos_token_id)
+
+            bad_words_ids = list(
+                filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
+            )

         # Forbidding a sequence is equivalent to setting its bias to -inf
         sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
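The filter above now iterates over a tensor rather than a list; a small self-contained sketch of the behavior (toy token ids), showing that any single-token "bad word" equal to an EOS id is dropped so the processor never forbids ending the sequence:

import torch

bad_words_ids = [[5], [2], [64, 21]]
eos_token_id = torch.tensor([2])  # iterating yields 0-d tensors; `2 == tensor(2)` is truthy
filtered = list(
    filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)
assert filtered == [[5], [64, 21]]  # the EOS entry [2] was removed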
@@ -1522,9 +1528,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
     Args:
         max_length (`int`):
             The maximum length of the sequence to be generated.
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
-            list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`):
+            The id(s) of the *end-of-sequence* token.

     Examples:

@@ -1548,15 +1553,22 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
     ```
     """

-    def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
+    def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
         self.max_length = max_length
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
+
+        if not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)
         self.eos_token_id = eos_token_id

+        if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
+            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
+        self.eos_token_id = self.eos_token_id.to(scores.device)
         scores_processed = scores
         if cur_len == self.max_length - 1:
             scores_processed = torch.full_like(scores, -math.inf)
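The new eager validation leans on the tensor form: float dtypes and negative ids are rejected with two vectorized checks instead of a per-element Python loop. A quick sketch, with a value chosen to trip the dtype check:

import torch

eos_token_id = torch.tensor([2.0])  # wrong dtype on purpose
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
    # the real processors raise ValueError here; we just report it
    print(f"rejected: `eos_token_id` has to be a list of positive integers, but is {eos_token_id}")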
@@ -1595,8 +1607,8 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
         exponential_decay_length_penalty (`tuple(int, float)`):
             This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
             starts and `decay_factor` represents the factor of exponential decay
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`):
+            The id(s) of the *end-of-sequence* token.
         input_ids_seq_length (`int`):
             The length of the input sequence.

@@ -1656,27 +1668,33 @@
     def __init__(
         self,
         exponential_decay_length_penalty: Tuple[int, float],
-        eos_token_id: Union[int, List[int]],
+        eos_token_id: Union[int, List[int], torch.Tensor],
         input_ids_seq_length: int,
     ):
         self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
         self.regulation_factor = exponential_decay_length_penalty[1]
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
+
+        if not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)
         self.eos_token_id = eos_token_id

+        if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
+            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
+        self.eos_token_id = self.eos_token_id.to(scores.device)
         penalties = torch.zeros_like(scores)
         scores_processed = scores
         if cur_len > self.regulation_start:
-            for i in self.eos_token_id:
-                penalty_idx = cur_len - self.regulation_start
-                # To support negative logits we compute the penalty of the absolute value and add to the original logit
-                penalty = torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1)
-                penalties[:, i] = penalty
-            scores_processed = scores + penalties
+            penalty_idx = cur_len - self.regulation_start
+            # To support negative logits we compute the penalty of the absolute value and add to the original logit
+            penalty = torch.abs(scores[:, self.eos_token_id]) * (pow(self.regulation_factor, penalty_idx) - 1)
+            penalties[:, self.eos_token_id] = penalty
+            scores_processed = scores + penalties
         return scores_processed

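The per-id Python loop is replaced by a single tensor-indexed assignment. A sketch checking that the two formulations agree (toy values; `regulation_factor` and the decay exponent are made up for the demo):

import torch

scores = torch.randn(2, 10)
eos_token_id = torch.tensor([2, 7])
regulation_factor, penalty_idx = 1.5, 3

# old: one fancy-indexing assignment per EOS id
penalties_loop = torch.zeros_like(scores)
for i in eos_token_id:
    penalties_loop[:, i] = torch.abs(scores[:, i]) * (pow(regulation_factor, penalty_idx) - 1)

# new: a single tensor-indexed assignment over all EOS ids at once
penalties_vec = torch.zeros_like(scores)
penalties_vec[:, eos_token_id] = torch.abs(scores[:, eos_token_id]) * (pow(regulation_factor, penalty_idx) - 1)

assert torch.allclose(penalties_loop, penalties_vec)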
@@ -1753,7 +1771,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
     """

     def __init__(self, begin_suppress_tokens, begin_index):
-        self.begin_suppress_tokens = list(begin_suppress_tokens)
+        self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
         self.begin_index = begin_index

     def set_begin_index(self, begin_index):
@@ -1762,8 +1780,8 @@ def set_begin_index(self, begin_index):
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        begin_suppress_tokens = torch.tensor(self.begin_suppress_tokens, device=scores.device)
-        suppress_token_mask = torch.isin(vocab_tensor, begin_suppress_tokens)
+        self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
+        suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
         scores_processed = scores
         if input_ids.shape[-1] == self.begin_index:
             scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
@@ -1801,13 +1819,13 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
     """

     def __init__(self, suppress_tokens):
-        self.suppress_tokens = list(suppress_tokens)
+        self.suppress_tokens = torch.tensor(list(suppress_tokens))

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        suppress_tokens = torch.tensor(self.suppress_tokens, device=scores.device)
-        suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens)
+        self.suppress_tokens = self.suppress_tokens.to(scores.device)
+        suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
         scores = torch.where(suppress_token_mask, -float("inf"), scores)
         return scores

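Storing the token tensor once and moving it with `.to(scores.device)` on use is what makes this cheap: `Tensor.to` returns the same tensor when no device or dtype change is needed, whereas the old code rebuilt a fresh tensor from a Python list on every generation step. A tiny demonstration:

import torch

suppress = torch.tensor([220, 50256])
same = suppress.to("cpu")
assert same is suppress  # no copy when device and dtype already match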

@@ -2268,23 +2286,30 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
     </Tip>

     Args:
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`):
+            The id(s) of the *end-of-sequence* token.
         min_eos_p (`float`, *optional*):
             Minimum end of speech threshold.
     """

-    def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
+    def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
+        if not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)
         self.eos_token_id = eos_token_id

+        if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
+            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+
         if min_eos_p is not None and min_eos_p <= 0:
             raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
         self.min_eos_p = min_eos_p

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         scores_processed = scores
+        self.eos_token_id = self.eos_token_id.to(scores.device)
         if self.min_eos_p:
             probs = torch.nn.functional.softmax(scores.float(), dim=-1)
             # create scores full of -inf except for the eos_token_id
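The visible diff stops at the comment; below is a hedged sketch of what that branch plausibly does, based only on the comment and the `min_eos_p` docstring above (shapes simplified, not the class's exact code): once the softmax probability of an EOS token clears the threshold, every non-EOS logit is forced to -inf so EOS wins the next step.

import torch

scores = torch.tensor([[0.0, 0.0, 4.0, 0.0]])    # (batch, vocab); id 2 plays EOS
eos_token_id = torch.tensor([2])
min_eos_p = 0.5

probs = torch.nn.functional.softmax(scores.float(), dim=-1)
if (probs[:, eos_token_id] > min_eos_p).any():
    # create scores full of -inf except for the eos_token_id
    forced = torch.full_like(scores, -float("inf"))
    forced[:, eos_token_id] = scores[:, eos_token_id]
    print(forced)  # only position 2 keeps a finite score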
