# Arbitrary head/size combinations for exercising the cache kernels in tests.
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]
# KV-cache memory layouts: "NHD" = (block, token, head, dim),
# "HND" = (block, head, token, dim) — TODO confirm against kv_cache_factory.
CACHE_LAYOUTS = ["NHD", "HND"]

# Parameters for MLA tests.
KV_LORA_RANKS = [512]
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
220221@pytest .mark .parametrize ("seed" , SEEDS ) 
221222@pytest .mark .parametrize ("device" , CUDA_DEVICES ) 
222223@pytest .mark .parametrize ("kv_cache_dtype" , KV_CACHE_DTYPE ) 
224+ @pytest .mark .parametrize ("kv_cache_layout" , CACHE_LAYOUTS ) 
223225@torch .inference_mode () 
224226def  test_reshape_and_cache_flash (
225227    kv_cache_factory_flashinfer ,
@@ -232,6 +234,7 @@ def test_reshape_and_cache_flash(
232234    seed : int ,
233235    device : str ,
234236    kv_cache_dtype : str ,
237+     kv_cache_layout : str ,
235238) ->  None :
236239    current_platform .seed_everything (seed )
237240    torch .set_default_device (device )
@@ -242,7 +245,6 @@ def test_reshape_and_cache_flash(
242245    slot_mapping  =  torch .tensor (slot_mapping_lst ,
243246                                dtype = torch .long ,
244247                                device = device )
245- 
246248    qkv  =  torch .randn (num_tokens ,
247249                      3 ,
248250                      num_heads ,
@@ -261,44 +263,56 @@ def test_reshape_and_cache_flash(
261263        kv_cache_dtype ,
262264        dtype ,
263265        device = device ,
266+         cache_layout = kv_cache_layout ,
264267    )
265-     key_cache , value_cache  =  key_caches [0 ].contiguous (
266-     ), value_caches [0 ].contiguous ()
268+     key_cache , value_cache  =  key_caches [0 ], value_caches [0 ]
267269    del  key_caches 
268270    del  value_caches 
269271
270272    k_scale  =  (key .amax () /  64.0 ).to (torch .float32 )
271273    v_scale  =  (value .amax () /  64.0 ).to (torch .float32 )
272274
275+     def  permute_and_compact (x ):
276+         y  =  x  if  kv_cache_layout  ==  "NHD"  else  x .permute (0 , 2 , 1 , 3 )
277+         return  y .contiguous ()
278+ 
279+     key_cache_compact  =  permute_and_compact (key_cache )
280+     value_cache_compact  =  permute_and_compact (value_cache )
281+ 
273282    # Clone the KV caches. 
274283    if  kv_cache_dtype  ==  "fp8" :
275-         cloned_key_cache  =  torch .empty_like (key_cache , dtype = torch .float16 )
276-         ops .convert_fp8 (cloned_key_cache , key_cache , k_scale .item (),
277-                         kv_cache_dtype )
278-         cloned_value_cache  =  torch .empty_like (value_cache , dtype = torch .float16 )
279-         ops .convert_fp8 (cloned_value_cache , value_cache , v_scale .item (),
284+         cloned_key_cache  =  torch .empty_like (key_cache_compact ,
285+                                             dtype = torch .float16 )
286+         ops .convert_fp8 (cloned_key_cache , key_cache_compact , k_scale .item (),
280287                        kv_cache_dtype )
288+         cloned_value_cache  =  torch .empty_like (value_cache_compact ,
289+                                               dtype = torch .float16 )
290+         ops .convert_fp8 (cloned_value_cache , value_cache_compact ,
291+                         v_scale .item (), kv_cache_dtype )
281292    else :
282-         cloned_key_cache  =  key_cache .clone ()
283-         cloned_value_cache  =  value_cache .clone ()
284- 
293+         cloned_key_cache  =  key_cache_compact .clone ()
294+         cloned_value_cache  =  value_cache_compact .clone ()
285295    # Call the reshape_and_cache kernel. 
286296    opcheck (torch .ops ._C_cache_ops .reshape_and_cache_flash ,
287297            (key , value , key_cache , value_cache , slot_mapping , kv_cache_dtype ,
288298             k_scale , v_scale ),
289299            cond = (head_size  ==  HEAD_SIZES [0 ]))
290300    ops .reshape_and_cache_flash (key , value , key_cache , value_cache ,
291301                                slot_mapping , kv_cache_dtype , k_scale , v_scale )
302+     key_cache_compact  =  permute_and_compact (key_cache )
303+     value_cache_compact  =  permute_and_compact (value_cache )
292304
293305    if  kv_cache_dtype  ==  "fp8" :
294-         result_key_cache  =  torch .empty_like (key_cache , dtype = torch .float16 )
306+         result_key_cache  =  torch .empty_like (key_cache_compact ,
307+                                             dtype = torch .float16 )
295308        ops .convert_fp8 (result_key_cache ,
296-                         key_cache ,
309+                         key_cache_compact ,
297310                        k_scale .item (),
298311                        kv_dtype = kv_cache_dtype )
299-         result_value_cache  =  torch .empty_like (value_cache , dtype = torch .float16 )
312+         result_value_cache  =  torch .empty_like (value_cache_compact ,
313+                                               dtype = torch .float16 )
300314        ops .convert_fp8 (result_value_cache ,
301-                         value_cache ,
315+                         value_cache_compact ,
302316                        v_scale .item (),
303317                        kv_dtype = kv_cache_dtype )
304318
@@ -310,8 +324,12 @@ def test_reshape_and_cache_flash(
310324    for  i  in  range (num_tokens ):
311325        block_idx  =  block_indicies_lst [i ]
312326        block_offset  =  block_offsets_lst [i ]
313-         cloned_key_cache [block_idx , block_offset , :, :] =  key [i ]
314-         cloned_value_cache [block_idx , block_offset , :, :] =  value [i ]
327+         if  kv_cache_layout  ==  "NHD" :
328+             cloned_key_cache [block_idx , block_offset , :, :] =  key [i ]
329+             cloned_value_cache [block_idx , block_offset , :, :] =  value [i ]
330+         else :
331+             cloned_key_cache [block_idx , :, block_offset , :] =  key [i ]
332+             cloned_value_cache [block_idx , :, block_offset , :] =  value [i ]
315333
316334    if  kv_cache_dtype  ==  "fp8" :
317335        torch .testing .assert_close (result_key_cache ,
@@ -323,8 +341,8 @@ def test_reshape_and_cache_flash(
323341                                   atol = 0.001 ,
324342                                   rtol = 0.1 )
325343    else :
326-         torch .testing .assert_close (key_cache , cloned_key_cache )
327-         torch .testing .assert_close (value_cache , cloned_value_cache )
344+         torch .testing .assert_close (key_cache_compact , cloned_key_cache )
345+         torch .testing .assert_close (value_cache_compact , cloned_value_cache )
328346
329347
330348@pytest .mark .parametrize ("direction" , COPYING_DIRECTION ) 
0 commit comments