 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
 
@@ -66,6 +67,19 @@ def get_kv_cache_shape(
     ) -> tuple[int, ...]:
         return (num_blocks, 2, block_size, num_kv_heads, head_size)
 
+    @staticmethod
+    def get_kv_cache_stride_order() -> tuple[int, ...]:
+        # `stride_order` indicates the permutation that gets us from
+        # `get_kv_cache_shape` to the actual memory layout we want.
+        cache_layout = get_kv_cache_layout()
+        if cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3, 4)
+        elif cache_layout == "HND":
+            stride_order = (0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
+        return stride_order
+
 
 @dataclass
 class PerLayerParameters:
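For reference, a minimal standalone sketch (not part of this diff) of what the "HND" stride order does to the cache shape returned by get_kv_cache_shape(); the sizes below are arbitrary:

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128
# Allocation shape from get_kv_cache_shape(): (num_blocks, 2, N, H, D), i.e. "NHD".
kv_cache = torch.empty(num_blocks, 2, block_size, num_kv_heads, head_size)

# Applying the "HND" stride order (0, 1, 3, 2, 4) swaps the block_size and
# head dimensions, giving (num_blocks, 2, num_kv_heads, block_size, head_size).
hnd_view = kv_cache.permute(0, 1, 3, 2, 4)
assert hnd_view.shape == (num_blocks, 2, num_kv_heads, block_size, head_size)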
@@ -290,7 +304,7 @@ def _get_workspace_buffer(self):
     def _get_prefill_wrapper(self):
         if self._prefill_wrapper is None:
             self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                self._get_workspace_buffer(), "NHD")
+                self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper
 
     def _get_decode_wrapper(self):
@@ -303,14 +317,14 @@ def _get_decode_wrapper(self):
                 num_qo_heads // num_kv_heads > 4)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
-                "NHD",
+                get_kv_cache_layout(),
                 use_tensor_cores=use_tensor_cores)
         return self._decode_wrapper
 
     def _get_cascade_wrapper(self):
         if self._cascade_wrapper is None:
             self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
-                2, self._get_workspace_buffer(), "NHD")
+                2, self._get_workspace_buffer(), get_kv_cache_layout())
         return self._cascade_wrapper
 
     def _plan(self, attn_metadata: FlashInferMetadata):
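A minimal sketch (not part of this diff) of where the layout string ends up: the FlashInfer wrappers accept the paged-KV layout ("NHD" or "HND") at construction time, so the value from get_kv_cache_layout() simply replaces the previously hardcoded "NHD". Assumes flashinfer is installed and a CUDA device is available; the buffer size is arbitrary.

import torch
import flashinfer

# 128 MiB workspace buffer, as the wrappers expect a uint8 workspace tensor.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
prefill = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "HND")
decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "HND",
                                                        use_tensor_cores=True)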
@@ -620,6 +634,7 @@ def forward(
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefill_tokens = attn_metadata.num_prefill_tokens
 
+        stride_order = FlashInferBackend.get_kv_cache_stride_order()
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()
@@ -634,7 +649,7 @@ def forward(
             assert prefill_wrapper._sm_scale == self.scale
             prefill_wrapper.run(
                 prefill_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
@@ -650,7 +665,7 @@ def forward(
             assert decode_wrapper._sm_scale == self.scale
             decode_wrapper.run(
                 decode_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[:num_decode_tokens],
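One note on the kv_cache.permute(*stride_order) calls above: permute() only reorders strides and copies no data, so applying the stride order on every forward call is cheap. A standalone sketch of that torch behavior:

import torch

kv_cache = torch.empty(4, 2, 16, 8, 128)
hnd_view = kv_cache.permute(0, 1, 3, 2, 4)
assert hnd_view.data_ptr() == kv_cache.data_ptr()  # same storage, no copy
assert hnd_view.shape == (4, 2, 8, 16, 128)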