 from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.utils import align_to_256bytes

 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -450,22 +449,13 @@ def _create_mla_cache(
     dtype: torch.dtype,
     kv_cache_dtype: str,
     device: str,
-    align_cache: bool,
 ) -> torch.Tensor:
     cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
-
-    if align_cache:
-        alloc_entry_size = align_to_256bytes(entry_size, cache_dtype)
-        alloc_shape = (num_blocks, block_size, alloc_entry_size)
-        cache_full = torch.zeros(alloc_shape, dtype=cache_dtype, device=device)
-        cache = cache_full[..., :entry_size]
-    else:
-        cache = torch.zeros(num_blocks,
-                            block_size,
-                            entry_size,
-                            dtype=cache_dtype,
-                            device=device)
-    return cache
+    return torch.zeros(num_blocks,
+                       block_size,
+                       entry_size,
+                       dtype=cache_dtype,
+                       device=device)


 def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
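Note: the `align_cache=True` path deleted above padded each cache entry out to a 256-byte boundary and then sliced the view back to `entry_size`, yielding aligned but non-contiguous rows. For reference, here is a rough reconstruction of the removed `vllm.utils.align_to_256bytes` helper, inferred from its call site; a sketch only, the actual implementation may have differed:

```python
import torch

def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
    # Round `extent` (in elements) up so that an entry of this dtype
    # spans a whole number of 256-byte chunks, then convert the padded
    # byte count back to an element count.
    elem_size = torch.tensor([], dtype=dtype).element_size()
    num_bytes = extent * elem_size
    aligned_bytes = (num_bytes + 255) // 256 * 256
    return aligned_bytes // elem_size
```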
@@ -488,7 +478,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
-@pytest.mark.parametrize("align_cache", [False])
 @torch.inference_mode()
 def test_concat_and_cache_mla(
     kv_lora_rank: int,
@@ -500,7 +489,6 @@ def test_concat_and_cache_mla(
     seed: int,
     device: str,
     kv_cache_dtype: str,
-    align_cache: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -520,7 +508,7 @@ def test_concat_and_cache_mla(

     scale = torch.tensor(0.1, dtype=torch.float32, device=device)
     kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                 kv_cache_dtype, device, align_cache)
+                                 kv_cache_dtype, device)
     ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)

     for i in range(num_tokens):
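For context on what this hunk exercises: `concat_and_cache_mla` writes each token's latent KV vector and its rope component into a single contiguous cache entry, and `ref_temp` accumulates the expected result. A hedged sketch of the per-token reference update the loop presumably performs (the names `kv_c`, `k_pe`, and `slot_mapping` are assumptions inferred from the test's parameters; they are not shown in this excerpt):

```python
# Sketch only, under the assumptions stated above.
# kv_c:  [num_tokens, kv_lora_rank]      latent KV vectors
# k_pe:  [num_tokens, qk_rope_head_dim]  rotary positional parts
# slot_mapping: [num_tokens]             flat slot index per token
for i in range(num_tokens):
    slot = slot_mapping[i].item()
    block_idx = slot // block_size
    block_offset = slot % block_size
    ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
    ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
```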
@@ -576,7 +564,6 @@ def test_concat_and_cache_mla(
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
-@pytest.mark.parametrize("align_cache", [False, True])
 @torch.inference_mode()
 def test_copy_blocks_mla(
     kv_lora_rank: int,
@@ -588,7 +575,6 @@ def test_copy_blocks_mla(
     seed: int,
     device: str,
     kv_cache_dtype: str,
-    align_cache: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -598,7 +584,7 @@ def test_copy_blocks_mla(
     kv_caches = []
     for _ in range(num_layers):
         kv_cache = _create_mla_cache(num_blocks, block_size, entry_size,
-                                     dtype, kv_cache_dtype, device,
-                                     align_cache)
+                                     dtype, kv_cache_dtype, device)
         _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
         kv_caches.append(kv_cache)

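The copy-blocks test then needs a mapping from source to destination block indices. A minimal sketch of how such a mapping could be built, assuming the MLA copy op consumes an `[n, 2]` tensor of `(src_block, dst_block)` pairs like vllm's other cache ops; the exact construction used by the test is not shown in this excerpt:

```python
import random
import torch

# Sketch only: pick disjoint source and destination blocks so no
# block is both copied from and copied over.
num_mappings = num_blocks // 2
src_blocks = random.sample(range(num_blocks), num_mappings)
remaining = sorted(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remaining, num_mappings)
block_mapping = torch.tensor(list(zip(src_blocks, dst_blocks)),
                             dtype=torch.int64, device=device)
```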
@@ -642,7 +628,6 @@ def test_copy_blocks_mla(
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
-@pytest.mark.parametrize("align_cache", [False, True])
 @torch.inference_mode()
 def test_swap_blocks_mla(
     kv_lora_rank: int,
@@ -653,17 +638,16 @@ def test_swap_blocks_mla(
     seed: int,
     device: str,
     kv_cache_dtype: str,
-    align_cache: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)

     entry_size = kv_lora_rank + qk_rope_head_dim

     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                  kv_cache_dtype, device, align_cache)
+                                  kv_cache_dtype, device)
     dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                  kv_cache_dtype, device, align_cache)
+                                  kv_cache_dtype, device)

     _fill_mla_cache(src_cache, kv_cache_dtype)
     _fill_mla_cache(dst_cache, kv_cache_dtype)
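After filling both caches, the test exercises block-level swaps between them. A hedged usage sketch, assuming `ops.swap_blocks` follows the `(src, dst, block_mapping)` convention of vllm's generic cache op, where each `block_mapping` row is a `(src_block, dst_block)` pair and the op copies the source block into the destination slot:

```python
# Sketch only, under the assumptions stated above.
block_mapping = torch.tensor([[0, 1], [2, 3]], dtype=torch.int64,
                             device="cpu")
ops.swap_blocks(src_cache, dst_cache, block_mapping)
# Each mapped destination block should now equal its source block.
for src_blk, dst_blk in block_mapping.tolist():
    torch.testing.assert_close(src_cache[src_blk].cpu().float(),
                               dst_cache[dst_blk].cpu().float())
```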
@@ -704,15 +688,14 @@ def test_swap_blocks_mla(
 @pytest.mark.parametrize("dtype", [torch.float32])
 @pytest.mark.parametrize("kv_cache_dtype",
                          ["auto"])  # You can also test "fp8" if needed.
-@pytest.mark.parametrize("align_cache", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
                           num_blocks, max_seq_len, batch_size, dtype,
-                          kv_cache_dtype, align_cache, device):
+                          kv_cache_dtype, device):
     entry_size = kv_lora_rank + qk_rope_head_dim
     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                  kv_cache_dtype, device, align_cache)
+                                  kv_cache_dtype, device)
     _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

     seq_len_tensor = torch.randint(0,