 from typing import Any, Optional
 
 import pytest
-from transformers import AutoTokenizer
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
 
 from vllm.inputs import token_inputs
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import (Detokenizer,
-                                                 detokenize_incrementally)
+from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
+                                        IncrementalDetokenizer,
+                                        SlowIncrementalDetokenizer)
+
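+# Truth strings with adjacent special tokens, used to exercise special-token
+# handling (skip_special_tokens / spaces_between_special_tokens) below.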
+SPECIAL_TOKS_TRUTH = [
+    "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
+]
 
 TRUTH = [
     "Hello here, this is a simple test",
     # incomplete UTF-8 characters
     # see https://github.com/vllm-project/vllm/pull/9625
     "ပုံပြင်လေးပြောပြပါ်",
-]
+] + SPECIAL_TOKS_TRUTH
+
 TOKENIZERS = [
     "facebook/opt-125m",
     "gpt2",
 ]
 
 
-def _run_incremental_decode(tokenizer, all_input_ids,
-                            skip_special_tokens: bool, starting_index: int):
-    decoded_text = ""
-    offset = 0
-    token_offset = 0
-    prev_tokens = None
-    for i in range(starting_index, len(all_input_ids)):
-        new_tokens, text, offset, token_offset = detokenize_incrementally(
-            tokenizer,
-            all_input_ids[:i + 1],
-            prev_tokens,
-            offset,
-            token_offset,
-            skip_special_tokens=skip_special_tokens)
-        decoded_text += text
-        if prev_tokens is None:
-            prev_tokens = new_tokens
-        else:
-            prev_tokens += new_tokens
-    return decoded_text
+def _run_incremental_decode(tokenizer,
+                            all_input_ids,
+                            skip_special_tokens: bool,
+                            starting_index: int,
+                            spaces_between_special_tokens: bool = True,
+                            fast: Optional[bool] = None):
+
+    prompt_token_ids = all_input_ids[:starting_index]
+
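+    # Build SamplingParams and a minimal EngineCoreRequest; request fields
+    # that the detokenizer does not consult are passed as None here.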
+    params = SamplingParams(
+        skip_special_tokens=skip_special_tokens,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+    )
+    request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
+                                params, None, 0.0, None)
+
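+    # fast=None exercises the from_new_request factory, which picks an
+    # implementation; fast=True/False pins the fast or slow variant directly.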
+    if fast is None:
+        detokenizer = IncrementalDetokenizer.from_new_request(
+            tokenizer, request)
+    elif fast:
+        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
+    else:
+        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)
+
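+    # Stream the generated ids one at a time, accumulating the text deltas
+    # exactly as a streaming client would observe them.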
+    output_text = ""
+    for i, token_id in enumerate(all_input_ids[starting_index:]):
+        detokenizer.update([token_id], False)
+        # Only the last generated token marks the stream finished; note the
+        # starting_index offset, since i indexes only the generated tokens.
+        finished = i == len(all_input_ids) - starting_index - 1
+        output_text += detokenizer.get_next_output_text(finished, delta=True)
+
+    return output_text, detokenizer.output_token_ids
 
 
 @pytest.fixture
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
     starting_index = 0
     all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
 
-    decoded_text = _run_incremental_decode(tokenizer,
-                                           all_input_ids,
-                                           skip_special_tokens=True,
-                                           starting_index=starting_index)
+    decoded_text, out_ids = _run_incremental_decode(
+        tokenizer,
+        all_input_ids,
+        skip_special_tokens=True,
+        starting_index=starting_index)
     assert decoded_text == truth
+    assert out_ids == all_input_ids[starting_index:]
 
 
 @pytest.fixture
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
 @pytest.mark.parametrize("with_prompt", [True, False])
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
-def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
+@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
+@pytest.mark.parametrize("fast", (True, False))
+def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
+                          spaces_between_special_tokens, fast):
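+    # The fast detokenizer only supports PreTrainedTokenizerFast, and the
+    # skip_special_tokens=True / spaces_between_special_tokens=False combo
+    # is not covered by these tests, so skip those parameterizations.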
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
+    if skip_special_tokens and not spaces_between_special_tokens:
+        pytest.skip()
+
+    if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
+        # Fix up inconsistency in fast/slow tokenizer behaviour.
+        tokenizer.add_special_tokens({
+            "additional_special_tokens": [
+                at for at in
+                tokenizer._tokenizer.get_added_tokens_decoder().values()
+                if at.special
+            ]
+        })
+
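+    # Only slow (PreTrainedTokenizer) decode() honours
+    # spaces_between_special_tokens, so only pass it in that case.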
+    extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
+        else {"spaces_between_special_tokens": spaces_between_special_tokens}
+
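+    # Build the reference ids with explicit BOS/EOS, then decode them to get
+    # the ground-truth text that incremental detokenization must reproduce.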
+    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
+    if tokenizer.bos_token_id is not None:
+        truth_tokens.insert(0, tokenizer.bos_token_id)
+    truth_tokens.append(tokenizer.eos_token_id)
+
+    new_truth = tokenizer.decode(truth_tokens,
+                                 skip_special_tokens=skip_special_tokens,
+                                 **extra_decode_args)
+
     if with_prompt:
-        truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
-        prompt_input_ids = truth_tokens[:len(truth) // 2]
-        generated_input_ids = truth_tokens[len(truth) // 2:]
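+        # Find the prompt boundary in token space by re-tokenizing the first
+        # half of the truth string (plus one for BOS, when present).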
+        num_prompt_tokens = len(
+            tokenizer(truth[:len(truth) // 2],
+                      add_special_tokens=False).input_ids)
+        if tokenizer.bos_token_id is not None:
+            num_prompt_tokens += 1
+
+        prompt_input_ids = truth_tokens[:num_prompt_tokens]
+        generated_input_ids = truth_tokens[num_prompt_tokens:]
         all_input_ids = prompt_input_ids + generated_input_ids
         starting_index = len(prompt_input_ids)
         prompt = tokenizer.decode(prompt_input_ids,
-                                  skip_special_tokens=skip_special_tokens)
-        generated = truth[len(prompt):]
+                                  skip_special_tokens=skip_special_tokens,
+                                  **extra_decode_args)
+
+        generated = new_truth[len(prompt):]
     else:
-        generated = truth
+        generated = new_truth
         starting_index = 0
-        all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
-        if skip_special_tokens:
-            if tokenizer.bos_token_id is not None:
-                all_input_ids = [tokenizer.bos_token_id] + all_input_ids
-                starting_index += 1
-            all_input_ids = all_input_ids + [tokenizer.eos_token_id]
+        all_input_ids = truth_tokens
 
-    decoded_text = _run_incremental_decode(
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer,
         all_input_ids,
         skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        starting_index=starting_index,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+        fast=fast)
 
     assert decoded_text == generated
+    assert out_ids == all_input_ids[starting_index:]
 
-    decoded_text = _run_incremental_decode(
+
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("fast", (True, False))
+def test_oov_decode(tokenizer, fast):
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
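+    # len(tokenizer) is one past the largest valid token id, so it is
+    # out-of-vocabulary; it should decode to empty text but still be
+    # tracked in output_token_ids.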
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer, [len(tokenizer)],
-        skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        skip_special_tokens=True,
+        starting_index=0,
+        spaces_between_special_tokens=True,
+        fast=fast)
 
     assert decoded_text == ''
+    assert out_ids == [len(tokenizer)]
 
 
 @pytest.fixture
@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
 @pytest.fixture(name="complete_sequence_token_ids")
 def create_complete_sequence_token_ids(complete_sequence: str,
                                        tokenizer) -> list[int]:
-    complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
-    return complete_sequence_token_ids
+    return tokenizer(complete_sequence, add_special_tokens=False).input_ids
 
 
 def create_sequence(prompt_token_ids=None):
-    prompt_token_ids = prompt_token_ids or [1]
+    prompt_token_ids = prompt_token_ids or []
     return Sequence(
         seq_id=0,
-        inputs=token_inputs(prompt_token_ids, prompt="<s>"),
+        inputs=token_inputs(prompt_token_ids),
         block_size=16,
     )
 
@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
     assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
     assert sequential_result != "".join(sequential_logprobs_text_other_token)
 
-    if skip_special_tokens:
+    if not skip_special_tokens:
         # Text for logprobs for the chosen token should be the same as the
         # generated text. Note that this only holds when we do not skip
         # special tokens, since the truth strings contain special tokens.
@@ -233,10 +300,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
+def test_decode_prompt_logprobs(complete_sequence: str,
+                                complete_sequence_token_ids: list[int],
                                 detokenizer: Detokenizer):
+
+    # We want to use skip_special_tokens=False here but Mistral tokenizers
+    # don't support that.
+    if complete_sequence not in SPECIAL_TOKS_TRUTH:
+        skip_special_tokens = True
+    elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
+                        MistralTokenizer):
+        skip_special_tokens = False
+    else:
+        pytest.skip("MistralTokenizers don't support "
+                    "skip_special_tokens=False")
+        return
238318 """Verify Detokenizer decodes prompt logprobs correctly."""
239- sampling_params = SamplingParams (skip_special_tokens = True ,
319+ sampling_params = SamplingParams (skip_special_tokens = skip_special_tokens ,
240320 prompt_logprobs = 1 )
241321
242322 # Run sequentially.
@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
     # decoded_prompt_logprobs doesn't contain the first token.
     token_ids = complete_sequence_token_ids
     tokenizer = detokenizer.get_tokenizer_for_seq(seq)
-    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
-    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text_full = tokenizer.decode(token_ids,
+                                 skip_special_tokens=skip_special_tokens)
+    text_first = tokenizer.decode(token_ids[0],
+                                  skip_special_tokens=skip_special_tokens)
     text = text_full[len(text_first):]
 
     # Text for logprobs for the chosen token should be the same as the