@@ -77,6 +77,49 @@ def _create_allowed_token_ids(
    return mask


+def _create_bad_words_token_ids(
+        batch_size: int, vocab_size: int,
+        bad_words_lengths: tuple[int, ...]) -> dict[int, list[list[int]]]:
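+    """Build random bad-word token sequences, one list per batch index.
+
+    For batch_size >= 2, one randomly chosen batch index is left without
+    bad words so the "no restriction" case is also covered.
+    """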
+    bad_words_token_ids = {}
+    for batch_idx in range(batch_size):
+        token_ids_single_batch = []
+        for bad_words_length in bad_words_lengths:
+            token_ids = np.random.choice(vocab_size,
+                                         size=bad_words_length,
+                                         replace=True).tolist()
+            token_ids_single_batch.append(token_ids)
+        bad_words_token_ids[batch_idx] = token_ids_single_batch
+    if batch_size >= 2:
+        # Test no bad_words for some batch
+        no_bad_words_batch_idx = np.random.choice(batch_size)
+        bad_words_token_ids.pop(no_bad_words_batch_idx, None)
+    return bad_words_token_ids
+
+
+def _update_output_token_ids_for_bad_words(
+        metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
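+    """Seed output_token_ids so the expected bad-word matches are known.
+
+    Single-token bad words are always expected to be banned. For each
+    multi-token bad word, either write its prefix into the tail of
+    output_token_ids (so its last token must be banned, at most once per
+    batch) or perturb the last output token to avoid an accidental match.
+    Returns the expected banned token ids, keyed by batch index.
+    """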
+    bad_words_last_tokens = {}
+    for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
+        output_token_ids = metadata.output_token_ids[batch_idx]
+        bad_words_last_token: list[int] = []
+        for i, bad_word_token_ids in enumerate(bad_words_token_ids):
+            if len(bad_word_token_ids) == 1:
+                # Single token id always affects logits
+                bad_words_last_token.append(bad_word_token_ids[0])
+            else:
+                prefix_length = len(bad_word_token_ids) - 1
+                has_bad_words = np.random.choice([True, False])
+                if has_bad_words:
+                    output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
+                    bad_words_last_token.append(bad_word_token_ids[-1])
+                    break  # Maximum one update to output_token_ids
+                else:  # Make sure no accidental match to bad words
+                    output_token_ids[-1] = (bad_word_token_ids[-2] +
+                                            1) % vocab_size
+        bad_words_last_tokens[batch_idx] = bad_words_last_token
+    return bad_words_last_tokens
+
+
def _create_default_sampling_metadata(
    num_output_tokens: int,
    batch_size: int,
@@ -112,6 +155,7 @@ def _create_default_sampling_metadata(
        min_tokens={},
        logit_bias=[None] * batch_size,
        allowed_token_ids_mask=None,
+        bad_words_token_ids={},
    )
    return fake_sampling_metadata

@@ -467,3 +511,35 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
                    "inf"), f"{batch_idx}, {token_id}"
            else:
                assert logits_for_req[token_id] != -float("inf")
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
+def test_sampler_bad_words(device: str, batch_size: int,
+                           bad_words_lengths: tuple[int, ...]):
+    """
+    Test to verify that when the bad words restriction is present, tokens
+    are penalized based on their match with the bad words.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
+        batch_size, VOCAB_SIZE, bad_words_lengths)
+    bad_words_last_tokens = _update_output_token_ids_for_bad_words(
+        sampling_metadata, VOCAB_SIZE)
+    sampler = Sampler()
+    logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        for token_id in range(VOCAB_SIZE):
+            if (batch_idx in bad_words_last_tokens
+                    and token_id in bad_words_last_tokens[batch_idx]):
+                assert logits_for_req[token_id] == -float("inf")
+            else:
+                assert logits_for_req[token_id] != -float("inf")
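
As a reading aid (not part of the commit; the function name below is invented for illustration), the matching rule the new test encodes can be sketched as: the last token of a bad word should be banned exactly when the bad word is a single token, or when the tokens generated so far end with the bad word's prefix.

# Illustrative sketch only; apply_bad_words is expected to enforce this rule.
def expected_banned_last_tokens(output_token_ids: list[int],
                                bad_words: list[list[int]]) -> list[int]:
    banned = []
    for bad_word in bad_words:
        prefix, last = bad_word[:-1], bad_word[-1]
        # Single-token bad words are always banned; longer ones only when
        # the generated suffix already equals their prefix.
        if not prefix or output_token_ids[-len(prefix):] == prefix:
            banned.append(last)
    return banned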