@@ -14,6 +14,7 @@ def _kv_cache_update_kernel(
1414    # Prefetch 
1515    slices_ref ,  # [3, padded_num_slices], list of (kv_cache_start, 
1616    # new_kv_start, slice_len) 
17+     num_slices_ref ,  # [1] 
1718    # Input 
1819    new_kv_hbm_ref ,  # [num_tokens, num_combined_kv_heads, head_dim] 
1920    kv_cache_hbm_ref ,  # [total_num_pages * page_size, num_combined_kv_heads, 
@@ -32,8 +33,10 @@ def _kv_cache_update_kernel(
3233    # Copy from new_kv_hbm_ref to scratch 
3334    for  i  in  range (num_slices_per_block ):
3435        offset_i  =  i  +  block_idx  *  num_slices_per_block 
35-         new_kv_start  =  slices_ref [1 , offset_i ]
36-         length  =  slices_ref [2 , offset_i ]
36+         new_kv_start  =  jax .lax .select (offset_i  <  num_slices_ref [0 ],
37+                                       slices_ref [1 , offset_i ], 0 )
38+         length  =  jax .lax .select (offset_i  <  num_slices_ref [0 ],
39+                                 slices_ref [2 , offset_i ], 0 )
3740        async_copy  =  pltpu .make_async_copy (
3841            new_kv_hbm_ref .at [pl .ds (new_kv_start , length ), ...],
3942            scratch .at [i , pl .ds (0 , length ), ...],
@@ -49,8 +52,10 @@ def _kv_cache_update_kernel(
4952    async_copies .clear ()
5053    for  i  in  range (num_slices_per_block ):
5154        offset_i  =  i  +  block_idx  *  num_slices_per_block 
52-         kv_cache_start  =  slices_ref [0 , offset_i ]
53-         length  =  slices_ref [2 , offset_i ]
55+         kv_cache_start  =  jax .lax .select (offset_i  <  num_slices_ref [0 ],
56+                                         slices_ref [0 , offset_i ], 0 )
57+         length  =  jax .lax .select (offset_i  <  num_slices_ref [0 ],
58+                                 slices_ref [2 , offset_i ], 0 )
5459        async_copy  =  pltpu .make_async_copy (
5560            scratch .at [i , pl .ds (0 , length ), ...],
5661            kv_cache_hbm_ref .at [pl .ds (kv_cache_start , length ), ...],
@@ -77,7 +82,6 @@ def kv_cache_update(
7782    page_size : int  =  32 ,
7883    num_slices_per_block : int  =  8 ,
7984):
80-     assert  slices .shape [1 ] %  num_slices_per_block  ==  0 
8185    _ , num_combined_kv_heads , head_dim  =  new_kv .shape 
8286    assert  kv_cache .shape [1 ] ==  num_combined_kv_heads 
8387    assert  kv_cache .shape [2 ] ==  head_dim 
@@ -93,7 +97,7 @@ def kv_cache_update(
9397    out_specs  =  [pl .BlockSpec (memory_space = pltpu .TPUMemorySpace .ANY )]
9498    out_shape  =  [jax .ShapeDtypeStruct (kv_cache .shape , dtype = kv_cache .dtype )]
9599
96-     scalar_prefetches  =  [slices ]
100+     scalar_prefetches  =  [slices ,  num_kv_update_slices ]
97101    scratch  =  pltpu .VMEM (
98102        (num_slices_per_block , page_size , num_combined_kv_heads , head_dim ),
99103        new_kv .dtype ,
0 commit comments