@@ -108,8 +108,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
     Args:
         min_length (`int`):
             The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`):
+            The id(s) of the *end-of-sequence* token.

     Examples:

@@ -137,23 +137,23 @@ class MinLengthLogitsProcessor(LogitsProcessor):
     ```
     """

-    def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
+    def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
         if not isinstance(min_length, int) or min_length < 0:
             raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")

-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
-            logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+        if not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)

         self.min_length = min_length
         self.eos_token_id = eos_token_id

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
-        eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
+        self.eos_token_id = self.eos_token_id.to(scores.device)
+        eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
         scores_processed = scores.clone()
         if input_ids.shape[-1] < self.min_length:
             scores_processed = torch.where(eos_token_mask, -math.inf, scores)
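For reviewers, a minimal sketch of the masking path this hunk converges on; the vocabulary size and EOS ids below are made-up illustration values, not taken from the PR:

```python
import math

import torch

# Illustrative values: a 10-token vocabulary where ids 8 and 9 act as EOS.
eos_token_id = torch.tensor([8, 9])
scores = torch.zeros(1, 10)  # one batch row of logits

# The same vectorized masking the processor now performs with a persistent tensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
scores_processed = torch.where(eos_token_mask, -math.inf, scores)

print(scores_processed[0, 8].item())  # -inf: EOS is unreachable below min_length
```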
@@ -171,8 +171,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
             input length.
         min_new_tokens (`int`):
             The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`):
+            The id(s) of the *end-of-sequence* token.

     Examples:

@@ -195,18 +195,20 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
     ```
     """

-    def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
+    def __init__(
+        self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor]
+    ):
         for arg_name, arg_value in [
             ("prompt_length_to_skip", prompt_length_to_skip),
             ("min_new_tokens", min_new_tokens),
         ]:
             if not isinstance(arg_value, int) or arg_value < 0:
                 raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")

-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
-            logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+        if not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)

         self.prompt_length_to_skip = prompt_length_to_skip
         self.min_new_tokens = min_new_tokens
@@ -217,8 +219,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
         new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
         scores_processed = scores.clone()
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
-        eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
+        self.eos_token_id = self.eos_token_id.to(scores.device)
+        eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
         if new_tokens_length < self.min_new_tokens:
             scores_processed = torch.where(eos_token_mask, -math.inf, scores)

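Both min-length hunks above share one pattern: normalize `eos_token_id` to a tensor once in `__init__`, then move it to the scores' device lazily in `__call__` instead of rebuilding a tensor every step. A hypothetical stand-in class (not part of transformers) sketching that pattern:

```python
import torch

class EosMaskSketch:  # hypothetical stand-in, not a transformers class
    def __init__(self, eos_token_id):
        # Normalize int / list inputs to a tensor exactly once.
        if not isinstance(eos_token_id, torch.Tensor):
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id)
        self.eos_token_id = eos_token_id

    def __call__(self, scores):
        # Cheap no-op after the first call on a given device.
        self.eos_token_id = self.eos_token_id.to(scores.device)
        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
        return torch.isin(vocab_tensor, self.eos_token_id)

print(EosMaskSketch(2)(torch.zeros(1, 5)))  # tensor([False, False,  True, False, False])
```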
@@ -1195,8 +1197,8 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
     Args:
         bad_words_ids (`List[List[int]]`):
             List of list of token ids that are not allowed to be generated.
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`, *optional*):
+            The id(s) of the *end-of-sequence* token.

     Examples:

@@ -1233,18 +1235,22 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
     ```
     """

-    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
+    def __init__(
+        self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None
+    ):
         self.bad_word_ids = bad_words_ids
         self._validate_arguments()

         # Filter EOS token from bad_words_ids
-        if eos_token_id is None:
-            eos_token_id = []
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        bad_words_ids = list(
-            filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
-        )
+        if eos_token_id is not None:
+            if not isinstance(eos_token_id, torch.Tensor):
+                if isinstance(eos_token_id, int):
+                    eos_token_id = [eos_token_id]
+                eos_token_id = torch.tensor(eos_token_id)
+
+            bad_words_ids = list(
+                filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
+            )

         # Forbidding a sequence is equivalent to setting its bias to -inf
         sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
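The EOS filter inside the rewritten constructor still iterates the ids; iterating a 1-D tensor yields scalar tensors, and `[2] == [torch.tensor(2)]` compares true element-wise, so the list filter behaves as before. A small check with made-up ids:

```python
import torch

# Made-up ids: 2 is EOS, [[2], [5, 6]] are the banned sequences.
eos_token_id = torch.tensor([2])
bad_words_ids = [[2], [5, 6]]

# Same filter as the constructor: single-token "bad words" equal to an EOS id
# are dropped so that EOS can still terminate generation.
bad_words_ids = list(
    filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)
print(bad_words_ids)  # [[5, 6]]
```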
@@ -1522,9 +1528,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
     Args:
         max_length (`int`):
             The maximum length of the sequence to be generated.
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
-            list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`):
+            The id(s) of the *end-of-sequence* token.

     Examples:

@@ -1548,15 +1553,22 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
     ```
     """

-    def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
+    def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
         self.max_length = max_length
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
+
+        if not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)
         self.eos_token_id = eos_token_id

+        if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
+            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
+        self.eos_token_id = self.eos_token_id.to(scores.device)
         scores_processed = scores
         if cur_len == self.max_length - 1:
             scores_processed = torch.full_like(scores, -math.inf)
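The new `ValueError` replaces the old `logger.warning` for malformed ids. Assuming a transformers build that includes this change, the check can be exercised directly:

```python
from transformers.generation.logits_process import ForcedEOSTokenLogitsProcessor

# Negative (or floating-point) ids are now rejected eagerly in the constructor.
try:
    ForcedEOSTokenLogitsProcessor(max_length=20, eos_token_id=[-1])
except ValueError as err:
    print(err)  # `eos_token_id` has to be a list of positive integers, but is tensor([-1])
```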
@@ -1595,8 +1607,8 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
         exponential_decay_length_penalty (`tuple(int, float)`):
             This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
             starts and `decay_factor` represents the factor of exponential decay
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`):
+            The id(s) of the *end-of-sequence* token.
         input_ids_seq_length (`int`):
             The length of the input sequence.

@@ -1656,27 +1668,33 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
     def __init__(
         self,
         exponential_decay_length_penalty: Tuple[int, float],
-        eos_token_id: Union[int, List[int]],
+        eos_token_id: Union[int, List[int], torch.Tensor],
         input_ids_seq_length: int,
     ):
         self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
         self.regulation_factor = exponential_decay_length_penalty[1]
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
+
+        if not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)
         self.eos_token_id = eos_token_id

+        if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
+            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
+        self.eos_token_id = self.eos_token_id.to(scores.device)
         penalties = torch.zeros_like(scores)
         scores_processed = scores
         if cur_len > self.regulation_start:
-            for i in self.eos_token_id:
-                penalty_idx = cur_len - self.regulation_start
-                # To support negative logits we compute the penalty of the absolute value and add to the original logit
-                penalty = torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1)
-                penalties[:, i] = penalty
-                scores_processed = scores + penalties
+            penalty_idx = cur_len - self.regulation_start
+            # To support negative logits we compute the penalty of the absolute value and add to the original logit
+            penalty = torch.abs(scores[:, self.eos_token_id]) * (pow(self.regulation_factor, penalty_idx) - 1)
+            penalties[:, self.eos_token_id] = penalty
+            scores_processed = scores + penalties
         return scores_processed

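The removed loop only ever used the loop variable as an index, so replacing it with one advanced-indexing assignment is behavior-preserving and avoids Python-level iteration over a tensor. A standalone sketch with made-up shapes:

```python
import torch

# Made-up shapes: batch of 2, 6-token vocabulary, EOS ids 4 and 5.
scores = torch.randn(2, 6)
eos_token_id = torch.tensor([4, 5])
regulation_factor, penalty_idx = 1.5, 3

penalties = torch.zeros_like(scores)
# One gather + one scatter instead of a per-id Python loop:
penalty = torch.abs(scores[:, eos_token_id]) * (pow(regulation_factor, penalty_idx) - 1)
penalties[:, eos_token_id] = penalty
scores_processed = scores + penalties

assert scores_processed.shape == scores.shape
```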
@@ -1753,7 +1771,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
     """

     def __init__(self, begin_suppress_tokens, begin_index):
-        self.begin_suppress_tokens = list(begin_suppress_tokens)
+        self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
         self.begin_index = begin_index

     def set_begin_index(self, begin_index):
@@ -1762,8 +1780,8 @@ def set_begin_index(self, begin_index):
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        begin_suppress_tokens = torch.tensor(self.begin_suppress_tokens, device=scores.device)
-        suppress_token_mask = torch.isin(vocab_tensor, begin_suppress_tokens)
+        self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
+        suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
         scores_processed = scores
         if input_ids.shape[-1] == self.begin_index:
             scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
@@ -1801,13 +1819,13 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
     """

     def __init__(self, suppress_tokens):
-        self.suppress_tokens = list(suppress_tokens)
+        self.suppress_tokens = torch.tensor(list(suppress_tokens))

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        suppress_tokens = torch.tensor(self.suppress_tokens, device=scores.device)
-        suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens)
+        self.suppress_tokens = self.suppress_tokens.to(scores.device)
+        suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
         scores = torch.where(suppress_token_mask, -float("inf"), scores)
         return scores

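Same storage change for the suppress-token processors: the token list becomes a tensor at construction and only migrates across devices once. Assuming a transformers build at this revision (where the constructor takes only `suppress_tokens`), usage is unchanged:

```python
import torch

from transformers.generation.logits_process import SuppressTokensLogitsProcessor

# Illustrative values: suppress ids 0 and 3 in a 5-token vocabulary.
processor = SuppressTokensLogitsProcessor(suppress_tokens=[0, 3])
scores = processor(input_ids=torch.tensor([[1]]), scores=torch.zeros(1, 5))
print(scores)  # ids 0 and 3 are -inf, the rest stay 0
```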
@@ -2268,23 +2286,30 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
     </Tip>

     Args:
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        eos_token_id (`Union[int, List[int], torch.Tensor]`):
+            The id(s) of the *end-of-sequence* token.
         min_eos_p (`float`, *optional*):
             Minimum end of speech threshold.
     """

-    def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
+    def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
+        if not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)
         self.eos_token_id = eos_token_id
+
+        if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
+            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+
         if min_eos_p is not None and min_eos_p <= 0:
             raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
         self.min_eos_p = min_eos_p

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         scores_processed = scores
+        self.eos_token_id = self.eos_token_id.to(scores.device)
         if self.min_eos_p:
             probs = torch.nn.functional.softmax(scores.float(), dim=-1)
             # create scores full of -inf except for the eos_token_id
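The hunk ends at the comment; for context, a rough sketch (made-up logits, not the Bark implementation itself) of what "scores full of -inf except for the eos_token_id" means once the EOS probability clears `min_eos_p`:

```python
import torch

# Made-up logits: token 2 plays the EOS role and dominates the distribution.
scores = torch.tensor([[0.0, 0.0, 5.0]])
eos_token_id = torch.tensor([2])
min_eos_p = 0.5

probs = torch.nn.functional.softmax(scores.float(), dim=-1)
if (probs[:, eos_token_id] > min_eos_p).any():
    early_stop_scores = torch.full_like(scores, -float("inf"))
    early_stop_scores[:, eos_token_id] = scores[:, eos_token_id]
    print(early_stop_scores)  # tensor([[-inf, -inf, 5.]])
```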