99import torch
1010
1111import vllm .v1 .core .kv_cache_utils as kv_cache_utils
12- from vllm .distributed .kv_events import AllBlocksCleared , BlockRemoved
12+ from vllm .distributed .kv_events import AllBlocksCleared , BlockRemoved , BlockStored
13+ from vllm .lora .request import LoRARequest
1314from vllm .multimodal .inputs import (
1415 MultiModalFeatureSpec ,
1516 MultiModalKwargsItem ,
@@ -59,6 +60,7 @@ def make_request(
5960 mm_hashes : list [str ] | None = None ,
6061 prompt_logprobs : int | None = None ,
6162 cache_salt : str | None = None ,
63+ lora_request : LoRARequest | None = None ,
6264):
6365 mm_features = []
6466 if mm_positions is not None :
@@ -79,7 +81,7 @@ def make_request(
7981 sampling_params = SamplingParams (max_tokens = 17 , prompt_logprobs = prompt_logprobs ),
8082 pooling_params = None ,
8183 eos_token_id = 100 ,
82- lora_request = None ,
84+ lora_request = lora_request ,
8385 cache_salt = cache_salt ,
8486 block_hasher = get_request_block_hasher (block_size , hash_fn ),
8587 )
@@ -1337,6 +1339,63 @@ def test_kv_cache_events(blocks_to_cache: int):
13371339 assert len (manager .block_pool .cached_block_hash_to_block ) == 0
13381340
13391341
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def test_kv_cache_events_with_lora(blocks_to_cache: int):
    """Check the ``lora_id`` carried by ``BlockStored`` KV cache events.

    Two requests with the same token ids are allocated one after the other
    on a single ``KVCacheManager`` that has KV cache events enabled:

    1. a request built with a ``LoRARequest`` whose ``lora_int_id`` is 42;
       the last emitted event must be a ``BlockStored`` with ``lora_id == 42``.
    2. a request built without any LoRA request; the last emitted event must
       be a ``BlockStored`` with ``lora_id is None``.

    In both cases the event must list ``blocks_to_cache`` block hashes and
    report the configured block size.

    Args:
        blocks_to_cache: Number of full blocks worth of tokens each request
            carries.
    """
    block_size = 16
    token_count = block_size * blocks_to_cache

    # The pool is sized one block larger than the number of blocks to cache.
    manager = KVCacheManager(
        make_kv_cache_config(block_size, blocks_to_cache + 1),
        max_model_len=8192,
        enable_caching=True,
        enable_kv_cache_events=True,
    )

    def allocate_and_check(request, expected_lora_id):
        """Allocate ``request`` and validate the last event it produced."""
        _ = manager.allocate_slots(request, token_count)
        stored = manager.take_events()[-1]

        assert isinstance(stored, BlockStored)
        if expected_lora_id is None:
            assert stored.lora_id is None
        else:
            assert stored.lora_id == expected_lora_id
        assert len(stored.block_hashes) == blocks_to_cache
        assert stored.block_size == block_size

    # --- Request that carries a LoRA adapter -------------------------------
    adapter = LoRARequest(
        lora_name="test_lora", lora_int_id=42, lora_path="/test/path"
    )
    lora_req = make_request(
        "lora_req",
        list(range(token_count)),
        block_size,
        sha256,
        lora_request=adapter,
    )
    # The event's lora_id must equal the lora_int_id given to LoRARequest.
    allocate_and_check(lora_req, 42)

    # Release the first request before allocating the second one.
    manager.free(lora_req)

    # --- Request without a LoRA adapter ------------------------------------
    plain_req = make_request(
        "no_lora_req", list(range(token_count)), block_size, sha256
    )
    # With no LoRA request attached, the event must report lora_id=None.
    allocate_and_check(plain_req, None)
1397+
1398+
13401399def test_eagle_enabled_removes_last_block ():
13411400 """Verify Eagle does NOT remove blocks when request
13421401 length is divisible by block size."""
0 commit comments