 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import unittest.mock as mock
 
 import pytest
 
     TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
     _get_padded_token_len, _get_req_paddings, _get_token_paddings)
 
-# Mock torch_xla module since it may not be available in the test environments
-torch_xla_patcher = mock.patch.dict(
-    "sys.modules", {
-        "torch_xla": mock.MagicMock(),
-        "torch_xla.core.xla_model": mock.MagicMock(),
-        "torch_xla.runtime": mock.MagicMock(),
-    })
-torch_xla_patcher.start()
 
-# Mock the PallasAttentionBackend
-pallas_attention_backend_patcher = mock.patch(
-    "vllm.v1.worker.tpu_model_runner.PallasAttentionBackend", )
-pallas_attention_backend_patcher.start()
-
-
-@pytest.fixture
-def model_runner():
-    # Patchers have already been started at module level.
+def get_vllm_config():
     scheduler_config = SchedulerConfig(
         max_num_seqs=10,
         max_num_batched_tokens=512,
@@ -60,18 +43,19 @@ def model_runner():
         cache_config=cache_config,
         scheduler_config=scheduler_config,
     )
+    return vllm_config
+
+
+def get_model_runner(vllm_config):
     device = "xla:0"  # Mocking TPU device
-    with mock.patch("vllm.v1.worker.tpu_model_runner.torch"), \
-        mock.patch("vllm.v1.worker.tpu_model_runner.xm"), \
-        mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
-        return TPUModelRunner(vllm_config, device)
+    return TPUModelRunner(vllm_config, device)
 
 
-@pytest.fixture(autouse=True, scope="session")
-def cleanup_patches():
-    yield
-    torch_xla_patcher.stop()
-    pallas_attention_backend_patcher.stop()
+@pytest.fixture
+def model_runner():
+    # Build a fresh config and runner for every test.
+    vllm_config = get_vllm_config()
+    return get_model_runner(vllm_config)
 
 
 def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
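
Splitting the old fixture into get_vllm_config() and get_model_runner() lets a test adjust the config before the runner is constructed (the KV-cache tests further down rely on this), while simpler tests keep taking the model_runner fixture. A minimal sketch of both usage patterns, with hypothetical test names, assuming the helpers above:

def test_with_fixture(model_runner):
    # plain tests just take the ready-made runner from the fixture
    assert model_runner is not None


def test_with_custom_config():
    # tests that must edit the config do so before building the runner
    vllm_config = get_vllm_config()
    vllm_config.model_config.max_model_len = 1_000_000  # example tweak
    runner = get_model_runner(vllm_config)
    assert runner is not None
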
@@ -370,12 +354,14 @@ def test_get_req_paddings():
     assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
 
 
-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
+def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(
+        model_runner):
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} must come before the current layer"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
         fwd_context = {
             # initialization below will fail because target layer is invalid;
             # the target layer needs to come before layer 1
@@ -399,13 +385,14 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
     assert fwd_context is not None
 
 
-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
+def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     invalid_layer = "model.layers.0.cross_attn.attn"
     error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
         fwd_context = {
             layer_0:
             Attention(
@@ -428,12 +415,13 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
     assert fwd_context is not None
 
 
-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_target_same_as_current():
+def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} cannot be the same as the current layer"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
         fwd_context = {
             # initialization below will fail because target layer is invalid;
             # the target layer needs to come before layer 1
@@ -457,11 +445,10 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
     assert fwd_context is not None
 
 
-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_without_kv_sharing(model_runner):
+def test_init_kv_cache_without_kv_sharing():
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
-    vllm_config = model_runner.vllm_config
+    vllm_config = get_vllm_config()
     with set_current_vllm_config(vllm_config):
         fwd_context = {
             layer_0:
@@ -482,33 +469,38 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
     # suppress var not used error
     assert fwd_context is not None
     # Set high context length to test max context length estimation
-    vllm_config.model_config.max_model_len = 3_000_000
+    vllm_config.model_config.max_model_len = 1_000_000
     vllm_ctx = vllm_config.compilation_config.static_forward_context
+    model_runner = get_model_runner(vllm_config)
     kv_cache_spec = model_runner.get_kv_cache_spec()
     assert len(kv_cache_spec) == 2
     assert len(model_runner.shared_kv_cache_layers) == 0
 
     available_memory = 20 * GiB_bytes
-    # page size for layer 0's kv_cache_spec is 32KB
-    num_expected_blocks = 327680  # 20GB / 32KB / 2 (num layers)
+    # page size for each layer KV can be calculated as
+    # 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
+    # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
+    num_expected_blocks = 20480  # 20GB / 512KB / 2 (num layers)
     kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                           available_memory)
     assert kv_cache_config.num_blocks == num_expected_blocks
-    assert len(kv_cache_config.tensors) == 2
-    assert kv_cache_config.tensors[layer_0].size == available_memory // 2
-    assert kv_cache_config.tensors[layer_1].size == available_memory // 2
+    assert len(kv_cache_config.kv_cache_tensors) == 2
+    assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
+    assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
 
     max_context_len = \
         estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
     # max context len with KV sharing should be 2x as large as without
-    assert max_context_len == 1310720
+    # max_context_len = available_memory / (page_size / block_size) / num_caches
+    # max_context_len = 5GB / (512KB / 128) / 2 = 655360
+    assert max_context_len == 655360
 
     # important: override tensor size to prevent large mem alloc during test
-    # this will only allocate 2 block worth of memory (2 * 32kb)
+    # this will only allocate 2 block worth of memory (2 * 512kb)
     kv_cache_config.num_blocks = 1
-    for layer in kv_cache_config.tensors:
-        kv_cache_config.tensors[layer].size = \
-            kv_cache_spec[layer].page_size_bytes
+    for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+        kv_cache_tensor.size = (
+            kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)
 
     model_runner.initialize_kv_cache(kv_cache_config)
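
The expected values in this hunk follow from the page-size arithmetic in the comments. A standalone sanity check of those numbers (a sketch that only reuses the constants quoted in the comments above, not the real test config):

# 2 (non-MLA, K and V) * 8 heads * 128 head_dim * 2 bytes (bfloat16) * 128 block_size
page_size_bytes = 2 * 8 * 128 * 2 * 128
assert page_size_bytes == 512 * 1024                        # 512KB per block, per layer
assert 20 * 1024**3 // page_size_bytes // 2 == 20480        # blocks across the 2 layers
# per-token bytes = page_size / block_size, split across the 2 per-layer caches
assert 5 * 1024**3 // (page_size_bytes // 128) // 2 == 655360
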
514506
@@ -524,11 +516,10 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
     assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
 
 
-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_valid(model_runner):
+def test_init_kv_cache_with_kv_sharing_valid():
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
-    vllm_config = model_runner.vllm_config
+    vllm_config = get_vllm_config()
     with set_current_vllm_config(vllm_config):
         fwd_context = {
             layer_0:
@@ -552,33 +543,34 @@ def test_init_kv_cache_with_kv_sharing_valid(model_runner):
     # Set high context length to test max context length estimation
     vllm_config.model_config.max_model_len = 3_000_000
     vllm_ctx = vllm_config.compilation_config.static_forward_context
+    model_runner = get_model_runner(vllm_config)
     kv_cache_spec = model_runner.get_kv_cache_spec()
     assert len(kv_cache_spec) == 1
     assert layer_0 in kv_cache_spec
     assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
 
     available_memory = 20 * GiB_bytes
-    # page size for layer 0's kv_cache_spec is 32KB
+    # page size for layer 0's kv_cache_spec is 512KB
     # with KV sharing, we can allocate (available_mem//page_size//1) blocks
     # which is twice as many as without KV sharing
-    num_expected_blocks = 655360  # 20GB / 32KB
+    num_expected_blocks = 2 * 20480  # 20GB / 512KB
     kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                           available_memory)
     assert kv_cache_config.num_blocks == num_expected_blocks
-    assert len(kv_cache_config.tensors) == 1
+    assert len(kv_cache_config.kv_cache_tensors) == 1
     # Each layer now has twice the available memory for KV cache
     # compared to no KV sharing
-    assert kv_cache_config.tensors[layer_0].size == available_memory
+    assert kv_cache_config.kv_cache_tensors[0].size == available_memory
 
     max_context_len = \
         estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
     # max context len with KV sharing should be 2x as large as without
-    assert max_context_len == 2 * 1310720
+    assert max_context_len == (2 * 655360)
 
     # important: override tensor size to prevent large mem alloc during test
-    # this will only allocate 1 block worth of memory (32kb)
+    # this will only allocate 1 block worth of memory (512kb)
     kv_cache_config.num_blocks = 1
-    kv_cache_config.tensors[layer_0].size = \
+    kv_cache_config.kv_cache_tensors[0].size = \
         kv_cache_spec[layer_0].page_size_bytes
 
     model_runner.initialize_kv_cache(kv_cache_config)
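
With KV sharing there is a single KV cache tensor, so the same page-size arithmetic yields twice the blocks and twice the estimated context length. A quick check of the figures used above (a sketch based on the quoted constants):

page_size_bytes = 512 * 1024
assert 20 * 1024**3 // page_size_bytes == 2 * 20480           # one shared cache -> 2x blocks
assert 5 * 1024**3 // (page_size_bytes // 128) == 2 * 655360  # 2x max context length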