1010
1111def  _kv_cache_update_kernel (
1212    # Prefetch 
13-     slices_ref ,  # [num_slices, 3] 
13+     slices_ref ,  # [3, num_slices], list of (kv_cache_start, new_kv_start, 
14+     # slice_len) 
1415    # Input 
15-     new_kv_hbm_ref ,  # [tokens, num_combined_kv_heads, head_dim] 
16-     kv_cache_hbm_ref ,
16+     new_kv_hbm_ref ,  # [num_tokens, num_combined_kv_heads, head_dim] 
17+     kv_cache_hbm_ref ,  # [total_num_pages * page_size, num_combined_kv_heads, 
18+     # head_dim] 
1719    # Output 
1820    _ ,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim] 
1921    # Scratch 
20-     scratch ,  # [block_size, page_size, num_combined_kv_heads, head_dim] 
22+     scratch ,  # [num_slices_per_block, page_size, num_combined_kv_heads, 
23+     # head_dim] 
2124    sem ,
2225):
2326    async_copies  =  []
2427    block_idx  =  pl .program_id (0 )
25-     block_size  =  scratch .shape [0 ]
28+     num_slices_per_block  =  scratch .shape [0 ]
2629
2730    # Copy from new_kv_hbm_ref to scratch 
28-     for  i  in  range (block_size ):
29-         offset_i  =  i  +  block_idx  *  block_size 
30-         new_kv_start  =  slices_ref [offset_i ,  1 ]
31-         length  =  slices_ref [offset_i ,  2 ]
31+     for  i  in  range (num_slices_per_block ):
32+         offset_i  =  i  +  block_idx  *  num_slices_per_block 
33+         new_kv_start  =  slices_ref [1 ,  offset_i ]
34+         length  =  slices_ref [2 ,  offset_i ]
3235        async_copy  =  pltpu .make_async_copy (
3336            new_kv_hbm_ref .at [pl .ds (new_kv_start , length ), ...],
3437            scratch .at [i , pl .ds (0 , length ), ...],
@@ -42,10 +45,10 @@ def _kv_cache_update_kernel(
4245
4346    # Copy from scratch to kv_cache_hbm_ref 
4447    async_copies .clear ()
45-     for  i  in  range (block_size ):
46-         offset_i  =  i  +  block_idx  *  block_size 
47-         kv_cache_start  =  slices_ref [offset_i ,  0 ]
48-         length  =  slices_ref [offset_i ,  2 ]
48+     for  i  in  range (num_slices_per_block ):
49+         offset_i  =  i  +  block_idx  *  num_slices_per_block 
50+         kv_cache_start  =  slices_ref [0 ,  offset_i ]
51+         length  =  slices_ref [2 ,  offset_i ]
4952        async_copy  =  pltpu .make_async_copy (
5053            scratch .at [i , pl .ds (0 , length ), ...],
5154            kv_cache_hbm_ref .at [pl .ds (kv_cache_start , length ), ...],
@@ -59,23 +62,25 @@ def _kv_cache_update_kernel(
5962
6063@functools .partial ( 
6164    jax .jit , 
62-     static_argnames = ["page_size" , "block_size " ], 
65+     static_argnames = ["page_size" , "num_slices_per_block " ], 
6366) 
6467def  kv_cache_update (
6568    new_kv : jax .Array ,  # [total_num_token, num_combined_kv_heads, head_dim] 
6669    slices : jax .
67-     Array ,  # [num_slices, 3 ], list of (kv_cache_start, new_kv_start, slice_len) 
70+     Array ,  # [3, num_slices ], list of (kv_cache_start, new_kv_start, slice_len) 
6871    kv_cache : jax .
6972    Array ,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim] 
7073    * ,
7174    page_size : int  =  32 ,
72-     block_size : int  =  8 ,
75+     num_slices_per_block : int  =  8 ,
7376):
74-     assert  slices .shape [0 ] %  block_size  ==  0 
77+     assert  slices .shape [1 ] %  num_slices_per_block  ==  0 
7578    _ , num_combined_kv_heads , head_dim  =  new_kv .shape 
7679    assert  kv_cache .shape [1 ] ==  num_combined_kv_heads 
7780    assert  kv_cache .shape [2 ] ==  head_dim 
7881    assert  head_dim  %  128  ==  0 
82+     # TODO: Add dynamic check to make sure that all the slice lengths are 
83+     # less than or equal to page_size 
7984
8085    in_specs  =  [
8186        pl .BlockSpec (memory_space = pltpu .TPUMemorySpace .ANY ),
@@ -87,7 +92,7 @@ def kv_cache_update(
8792
8893    scalar_prefetches  =  [slices ]
8994    scratch  =  pltpu .VMEM (
90-         (block_size , page_size , num_combined_kv_heads , head_dim ),
95+         (num_slices_per_block , page_size , num_combined_kv_heads , head_dim ),
9196        new_kv .dtype ,
9297    )
9398
@@ -102,7 +107,7 @@ def kv_cache_update(
102107            num_scalar_prefetch = len (scalar_prefetches ),
103108            in_specs = in_specs ,
104109            out_specs = out_specs ,
105-             grid = (slices .shape [0 ] //  block_size , ),
110+             grid = (slices .shape [1 ] //  num_slices_per_block , ),
106111            scratch_shapes = scratch_shapes ,
107112        ),
108113        out_shape = out_shape ,
0 commit comments