11# SPDX-License-Identifier: Apache-2.0
22"""Compare the with and without prefix caching."""
33
4+ from typing import Optional
5+
46import pytest
57
68from vllm .multimodal .inputs import MultiModalKwargs , PlaceholderRange
1517def make_request (request_id ,
1618 prompt_token_ids ,
1719 mm_positions = None ,
18- mm_hashes = None ):
20+ mm_hashes = None ,
21+ prompt_logprobs : Optional [int ] = None ):
1922 if mm_positions is None :
2023 multi_modal_inputs = None
2124 else :
@@ -28,7 +31,8 @@ def make_request(request_id,
2831 multi_modal_inputs = multi_modal_inputs ,
2932 multi_modal_hashes = mm_hashes ,
3033 multi_modal_placeholders = mm_positions ,
31- sampling_params = SamplingParams (max_tokens = 17 ),
34+ sampling_params = SamplingParams (max_tokens = 17 ,
35+ prompt_logprobs = prompt_logprobs ),
3236 eos_token_id = 100 ,
3337 arrival_time = 0 ,
3438 lora_request = None ,
@@ -144,6 +148,110 @@ def test_prefill():
144148 assert manager .block_pool .free_block_queue .free_list_tail is None
145149
146150
151+ def test_prefill_plp ():
152+ '''Test prefill with APC and some prompt logprobs (plp) requests.
153+
154+ 1. Schedule plp request and validate APC block allocation
155+ 2. Schedule non-plp request and validate blocks
156+ 3. Schedule plp request; no hit should occur; validate blocks
157+ '''
158+ manager = KVCacheManager (
159+ block_size = 16 ,
160+ num_gpu_blocks = 10 ,
161+ max_model_len = 8192 ,
162+ sliding_window = None ,
163+ enable_caching = True ,
164+ num_preallocate_tokens = 16 ,
165+ )
166+
167+ # Complete 3 blocks (48 tokens)
168+ common_token_ids = [i for i in range (3 ) for _ in range (16 )]
169+
170+ # Request #0 is a prompt logprobs request
171+ # Fully cache miss
172+ # Incomplete 1 block (7 tokens)
173+ unique_token_ids = [3 ] * 7
174+ all_token_ids = common_token_ids + unique_token_ids
175+ req0 = make_request ("0" , all_token_ids , prompt_logprobs = 5 )
176+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req0 )
177+ assert len (manager .req_to_block_hashes [req0 .request_id ]) == 3
178+ assert not computed_blocks
179+ assert num_computed_tokens == 0
180+ blocks = manager .allocate_slots (req0 , 55 , computed_blocks )
181+ assert [b .block_id for b in blocks ] == [0 , 1 , 2 , 3 , 4 ]
182+ req0_block_hashes = [b .block_hash for b in blocks ]
183+
184+ # Check full block metadata
185+ parent_block_hash = None
186+ for block_id in (0 , 1 , 2 ):
187+ block_tokens = tuple (all_token_ids [block_id * 16 :(block_id + 1 ) * 16 ])
188+ block_hash = hash_block_tokens (parent_block_hash , block_tokens )
189+ assert manager .block_pool .blocks [block_id ].block_hash == block_hash
190+ assert manager .block_pool .blocks [block_id ].ref_cnt == 1
191+ parent_block_hash = block_hash .hash_value
192+
193+ # Check partial/preallocated block metadata
194+ for block_id in (3 , 4 ):
195+ assert manager .block_pool .blocks [block_id ].block_hash is None
196+ assert manager .block_pool .blocks [block_id ].ref_cnt == 1
197+
198+ # Request #1 is a non-prompt-logprobs request:
199+ # Cache hit in the common prefix when the original block is still in use.
200+ # Incomplete 1 block (5 tokens)
201+ unique_token_ids = [3 ] * 5
202+ req1 = make_request ("1" , common_token_ids + unique_token_ids )
203+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req1 )
204+ assert len (manager .req_to_block_hashes [req1 .request_id ]) == 3
205+ assert [b .block_id for b in computed_blocks ] == [0 , 1 , 2 ]
206+ assert num_computed_tokens == 3 * 16
207+ num_new_tokens = 53 - 3 * 16
208+ blocks = manager .allocate_slots (req1 , num_new_tokens , computed_blocks )
209+ assert [b .block_id for b in blocks ] == [5 , 6 ]
210+ for block in computed_blocks :
211+ assert block .ref_cnt == 2
212+
213+ # At this point, we should have 3 free blocks left.
214+ assert manager .block_pool .free_block_queue .num_free_blocks == 3
215+
216+ manager .free (req0 )
217+ manager .free (req1 )
218+
219+ # All blocks should be available.
220+ assert manager .block_pool .free_block_queue .num_free_blocks == 10
221+ # The order should be
222+ # [unallocated (7, 8, 9)]
223+ # [unique_req0 (4, 3)]
224+ # [unique_req1 (6, 5)]
225+ # [common (2, 1, 0)]
226+ assert [
227+ b .block_id
228+ for b in manager .block_pool .free_block_queue .get_all_free_blocks ()
229+ ] == [7 , 8 , 9 , 4 , 3 , 6 , 5 , 2 , 1 , 0 ]
230+
231+ # Request #2 is a prompt-logprobs request:
232+ # NO cache hit in the common prefix; duplicates request #0 cached blocks
233+ unique_token_ids = [3 ] * 6
234+ req2 = make_request ("2" ,
235+ common_token_ids + unique_token_ids ,
236+ prompt_logprobs = 5 )
237+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req2 )
238+ assert len (manager .req_to_block_hashes [req2 .request_id ]) == 3
239+ assert not computed_blocks
240+ assert num_computed_tokens == 0
241+ blocks = manager .allocate_slots (req2 , 55 , computed_blocks )
242+ block_ids = [b .block_id for b in blocks ]
243+ # Duplicate cached blocks have different ids but same hashes vs request #0
244+ assert [b .block_hash for b in blocks ] == req0_block_hashes
245+ assert block_ids != [0 , 1 , 2 , 3 , 4 ]
246+
247+ # Request #2 block hashes are valid since request #0 hashes are.
248+ # Check block reference counts.
249+ for block_id in block_ids :
250+ assert manager .block_pool .blocks [block_id ].ref_cnt == 1
251+
252+ manager .free (req2 )
253+
254+
147255def test_decode ():
148256 manager = KVCacheManager (
149257 block_size = 16 ,
0 commit comments