@@ -14,6 +14,7 @@ def _kv_cache_update_kernel(
1414 # Prefetch
1515 slices_ref , # [3, padded_num_slices], list of (kv_cache_start,
1616 # new_kv_start, slice_len)
17+ num_slices_ref , # [1]
1718 # Input
1819 new_kv_hbm_ref , # [num_tokens, num_combined_kv_heads, head_dim]
1920 kv_cache_hbm_ref , # [total_num_pages * page_size, num_combined_kv_heads,
@@ -32,8 +33,10 @@ def _kv_cache_update_kernel(
3233 # Copy from new_kv_hbm_ref to scratch
3334 for i in range (num_slices_per_block ):
3435 offset_i = i + block_idx * num_slices_per_block
35- new_kv_start = slices_ref [1 , offset_i ]
36- length = slices_ref [2 , offset_i ]
36+ new_kv_start = jax .lax .select (offset_i < num_slices_ref [0 ],
37+ slices_ref [1 , offset_i ], 0 )
38+ length = jax .lax .select (offset_i < num_slices_ref [0 ],
39+ slices_ref [2 , offset_i ], 0 )
3740 async_copy = pltpu .make_async_copy (
3841 new_kv_hbm_ref .at [pl .ds (new_kv_start , length ), ...],
3942 scratch .at [i , pl .ds (0 , length ), ...],
@@ -49,8 +52,10 @@ def _kv_cache_update_kernel(
4952 async_copies .clear ()
5053 for i in range (num_slices_per_block ):
5154 offset_i = i + block_idx * num_slices_per_block
52- kv_cache_start = slices_ref [0 , offset_i ]
53- length = slices_ref [2 , offset_i ]
55+ kv_cache_start = jax .lax .select (offset_i < num_slices_ref [0 ],
56+ slices_ref [0 , offset_i ], 0 )
57+ length = jax .lax .select (offset_i < num_slices_ref [0 ],
58+ slices_ref [2 , offset_i ], 0 )
5459 async_copy = pltpu .make_async_copy (
5560 scratch .at [i , pl .ds (0 , length ), ...],
5661 kv_cache_hbm_ref .at [pl .ds (kv_cache_start , length ), ...],
@@ -77,7 +82,6 @@ def kv_cache_update(
7782 page_size : int = 32 ,
7883 num_slices_per_block : int = 8 ,
7984):
80- assert slices .shape [1 ] % num_slices_per_block == 0
8185 _ , num_combined_kv_heads , head_dim = new_kv .shape
8286 assert kv_cache .shape [1 ] == num_combined_kv_heads
8387 assert kv_cache .shape [2 ] == head_dim
@@ -93,7 +97,7 @@ def kv_cache_update(
9397 out_specs = [pl .BlockSpec (memory_space = pltpu .TPUMemorySpace .ANY )]
9498 out_shape = [jax .ShapeDtypeStruct (kv_cache .shape , dtype = kv_cache .dtype )]
9599
96- scalar_prefetches = [slices ]
100+ scalar_prefetches = [slices , num_kv_update_slices ]
97101 scratch = pltpu .VMEM (
98102 (num_slices_per_block , page_size , num_combined_kv_heads , head_dim ),
99103 new_kv .dtype ,
0 commit comments