@@ -57,6 +57,26 @@ def _create_logit_bias(
     return res


+def _create_allowed_token_ids(
+    batch_size: int,
+    vocab_size: int,
+    num_allowed_token_ids: int,
+    device: torch.device,
+) -> Optional[torch.Tensor]:
+    mask: Optional[torch.Tensor] = None
+    for i in range(batch_size):
+        if i % 2 == 1:
+            continue
+        if mask is None:
+            mask = torch.zeros((batch_size, vocab_size),
+                               dtype=torch.bool,
+                               device=device)
+        start = min(i, vocab_size - 1)
+        end = min(i + num_allowed_token_ids, vocab_size - 1)
+        mask[i, start:end] = True
+    return mask
+
+
 def _create_default_sampling_metadata(
     num_output_tokens: int,
     batch_size: int,
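For reference, a small usage sketch of the `_create_allowed_token_ids` helper added above (not part of the diff): it flags a window of `num_allowed_token_ids` token IDs for every even-indexed request and leaves odd-indexed rows all-False. The snippet assumes the helper is in scope (e.g. imported from this test module).

```python
import torch

# Hypothetical usage of the new helper with a tiny batch and vocabulary.
# With batch_size=4, vocab_size=8, num_allowed_token_ids=2 the result is:
#   row 0 -> columns [0, 2) are True
#   row 1 -> all False (odd index is skipped)
#   row 2 -> columns [2, 4) are True
#   row 3 -> all False
mask = _create_allowed_token_ids(batch_size=4,
                                 vocab_size=8,
                                 num_allowed_token_ids=2,
                                 device=torch.device("cpu"))
assert mask is not None and mask.shape == (4, 8)
assert mask[0, :2].all() and not mask[1].any()
assert mask[2, 2:4].all() and not mask[3].any()
```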
@@ -92,6 +112,7 @@ def _create_default_sampling_metadata(
         no_penalties=True,
         min_tokens={},
         logit_bias=[None] * batch_size,
+        allowed_token_ids_mask=None,
     )
     return fake_sampling_metadata

@@ -253,7 +274,10 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     sampling_metadata.frequency_penalties = _create_penalty_tensor(
         batch_size, frequency_penalty, torch.device(device))
     output_token_ids, sorted_token_ids_in_output = \
-        _create_weighted_output_token_list(batch_size, VOCAB_SIZE)
+        _create_weighted_output_token_list(
+            batch_size,
+            VOCAB_SIZE,
+        )
     sampling_metadata.output_token_ids = output_token_ids
     sampling_metadata.no_penalties = False
     sampler = Sampler()
@@ -262,8 +286,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     for batch_idx in range(batch_size):
         non_penalized_token_id = logits[batch_idx].argmax().item()
         penalized_token_id = logits[batch_idx].argmin().item()
-        distinct_sorted_token_ids_in_output = \
-            sorted_token_ids_in_output[batch_idx]
+        distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
+            batch_idx]
         most_frequent_token_id = distinct_sorted_token_ids_in_output[
             len(distinct_sorted_token_ids_in_output) - 1]
         if frequency_penalty > 0:
@@ -272,8 +296,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
             # non-penalized token ID is not present in the output, while the
             # most penalized token is the one that occurs most frequently in
             # the output.
-            assert non_penalized_token_id \
-                not in distinct_sorted_token_ids_in_output
+            assert (non_penalized_token_id
+                    not in distinct_sorted_token_ids_in_output)
             assert penalized_token_id == most_frequent_token_id
         elif frequency_penalty < 0:
             # If `frequency_penalty` is set to < 0, it indicates
@@ -282,8 +306,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
             # in the output, while the penalized token ID is one that has not
             # yet appeared.
             assert non_penalized_token_id == most_frequent_token_id
-            assert penalized_token_id \
-                not in distinct_sorted_token_ids_in_output
+            assert penalized_token_id not in distinct_sorted_token_ids_in_output


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -318,18 +341,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
             # If `repetition_penalty` > 1.0, verify that the non-penalized
             # token ID has not been seen before, while the penalized token ID
             # exists either in the prompt or the output.
-            assert (non_penalized_token_id not in prompt_tokens and \
-                    non_penalized_token_id not in output_tokens)
-            assert (penalized_token_id in prompt_tokens or \
-                    penalized_token_id in output_tokens)
+            assert (non_penalized_token_id not in prompt_tokens
+                    and non_penalized_token_id not in output_tokens)
+            assert (penalized_token_id in prompt_tokens
+                    or penalized_token_id in output_tokens)
         elif repetition_penalty < 1.0:
             # If `repetition_penalty` < 1.0, verify that the penalized
             # token ID has not been seen before, while the non-penalized
             # token ID exists either in the prompt or the output.
-            assert (penalized_token_id not in prompt_tokens and \
-                    penalized_token_id not in output_tokens)
-            assert (non_penalized_token_id in prompt_tokens or \
-                    non_penalized_token_id in output_tokens)
+            assert (penalized_token_id not in prompt_tokens
+                    and penalized_token_id not in output_tokens)
+            assert (non_penalized_token_id in prompt_tokens
+                    or non_penalized_token_id in output_tokens)


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -404,3 +427,44 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
                                                                  1e-2)
             else:
                 assert logits_for_req[token_id] == pytest.approx(1e-2)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
+def test_sampler_allowed_token_ids(device: str, batch_size: int,
+                                   num_allowed_token_ids: int):
+    """
+    Test to verify that the allowed_token_ids_mask is applied correctly:
+    token IDs flagged in the mask have their logits set to -inf, while
+    all other logits are left untouched.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    mask = _create_allowed_token_ids(
+        batch_size=batch_size,
+        vocab_size=VOCAB_SIZE,
+        num_allowed_token_ids=num_allowed_token_ids,
+        device=device,
+    )
+    sampling_metadata.allowed_token_ids_mask = mask
+    sampler = Sampler()
+    logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        if batch_idx % 2 == 1:
+            assert torch.all(logits_for_req != -float("inf"))
+            continue
+        for token_id in range(VOCAB_SIZE):
+            start = min(batch_idx, VOCAB_SIZE - 1)
+            end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
+            if token_id >= start and token_id < end:
+                assert logits_for_req[token_id] == -float(
+                    "inf"), f"{batch_idx}, {token_id}"
+            else:
+                assert logits_for_req[token_id] != -float("inf")
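The test above exercises `Sampler.apply_allowed_token_ids` via the `allowed_token_ids_mask` field of `SamplingMetadata`. As a rough mental model of the behaviour its assertions expect, here is a minimal sketch (not the actual vLLM implementation; `apply_mask_sketch` is a hypothetical stand-in): positions where the boolean mask is True are filled with -inf, everything else is left alone.

```python
import torch


def apply_mask_sketch(logits: torch.Tensor,
                      allowed_token_ids_mask: torch.Tensor) -> torch.Tensor:
    # Positions where the boolean mask is True are suppressed with -inf;
    # all other logits pass through unchanged.
    return logits.masked_fill(allowed_token_ids_mask, float("-inf"))


# Tiny example: 2 requests over a 4-token vocabulary.
logits = torch.full((2, 4), 1e-2)
mask = torch.zeros((2, 4), dtype=torch.bool)
mask[0, 1:3] = True  # suppress tokens 1 and 2 for request 0 only
out = apply_mask_sketch(logits, mask)
assert torch.isinf(out[0, 1]) and torch.isinf(out[0, 2])
assert torch.isfinite(out[1]).all()
```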