 from typing import Any, Optional

 import torch
-# Required to register custom ops.
+import torch_xla.core.xla_builder as xb
 import torch_xla.experimental.custom_kernel  # noqa: F401
+# Required to register custom ops.
+from torch.library import impl
+from torch_xla._internal.jax_workarounds import requires_jax
+from torch_xla.experimental.custom_kernel import XLA_LIB

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
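
For readers new to the pattern used further down: `XLA_LIB` is torch_xla's `torch.library.Library` for the `xla` namespace, so any op registered on it becomes callable as `torch.ops.xla.<name>`. A minimal sketch of the define/impl pattern, using a hypothetical toy op:

    import torch
    from torch.library import impl
    from torch_xla.experimental.custom_kernel import XLA_LIB

    # Declare the schema, then attach a backend-specific implementation.
    XLA_LIB.define("toy_identity(Tensor x) -> Tensor")

    @impl(XLA_LIB, "toy_identity", "XLA")
    def toy_identity_xla(x: torch.Tensor) -> torch.Tensor:
        return x  # afterwards callable as torch.ops.xla.toy_identity(x)
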
@@ -107,6 +111,7 @@ class PallasMetadata:
     context_lens: torch.Tensor
     query_start_loc: torch.Tensor
     num_seqs: torch.Tensor
+    num_slices_per_kv_cache_update_block: int


 class PallasAttentionBackendImpl(AttentionImpl):
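
The new metadata field is threaded through to `write_to_kv_cache` below: it fixes how many slot-mapping slices each block of the Pallas update kernel handles, which in turn determines the kernel grid. A sketch of the rounding this implies, with a hypothetical helper name:

    def pad_num_slices(num_slices: int, num_slices_per_block: int) -> int:
        # Round up so the grid divides evenly; padded slices are no-ops.
        return ((num_slices + num_slices_per_block - 1) //
                num_slices_per_block) * num_slices_per_block
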
@@ -212,7 +217,9 @@ def forward(
         # Write input keys and values to the KV cache.
         # Skip this if sharing KV cache with an earlier attention layer.
         slot_mapping = attn_metadata.slot_mapping
-        write_to_kv_cache(key, value, kv_cache, slot_mapping)
+        write_to_kv_cache(
+            key, value, kv_cache, slot_mapping,
+            attn_metadata.num_slices_per_kv_cache_update_block)

         output = torch.ops.xla.ragged_paged_attention(
             query,
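
For context, `slot_mapping` assigns each incoming token a destination in the paged cache. In the flattened view used by `write_to_kv_cache`, a slot is a row index of `kv_cache.flatten(0, 1)`, i.e. `block_number * block_size + block_offset`, which is what the `index_copy_` path being replaced relied on. Illustrative arithmetic:

    block_size = 16                    # illustrative value
    block_number, block_offset = 3, 5
    slot = block_number * block_size + block_offset  # row 53 of the flat cache
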
@@ -244,16 +251,17 @@ def write_to_kv_cache(
     value: torch.Tensor,
     kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
+    num_slices_per_kv_cache_update_block: int,
 ) -> None:
     """Write the key and values to the KV cache.

     Args:
         key: shape = [num_tokens, num_kv_heads * head_size]
         value: shape = [num_tokens, num_kv_heads * head_size]
         kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
-
+        num_slices_per_kv_cache_update_block: slices handled per kernel block
     """
-    _, _, num_combined_kv_heads, head_size = kv_cache.shape
+    _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
     head_size = cdiv(head_size,
                      TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
@@ -262,4 +270,41 @@ def write_to_kv_cache(
     torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)

     kv_cache = kv_cache.flatten(0, 1)
-    kv_cache.index_copy_(0, slot_mapping, kv)
+    new_kv_cache = torch.ops.xla.kv_cache_update_op(
+        kv, slot_mapping, kv_cache, page_size,
+        num_slices_per_kv_cache_update_block)
+    # NOTE: the in-place copy will be optimized away by the XLA compiler.
+    kv_cache.copy_(new_kv_cache)
+
+
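Functionally, the new op computes the same result as the `index_copy_` it replaces, but as a Pallas kernel that returns a fresh buffer; the trailing `copy_` writes the result back into the donated input, which XLA can turn into an alias. Reference semantics in plain PyTorch (not the kernel itself):

    def kv_cache_update_reference(kv: torch.Tensor, slot_mapping: torch.Tensor,
                                  flat_kv_cache: torch.Tensor) -> torch.Tensor:
        # Out-of-place equivalent of
        # flat_kv_cache.index_copy_(0, slot_mapping, kv)
        out = flat_kv_cache.clone()
        out[slot_mapping] = kv
        return out
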
+@requires_jax
+def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                            kv_cache: torch.Tensor, page_size: int,
+                            num_slices_per_block: int):
+    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
+    new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
+        "page_size": page_size,
+        "num_slices_per_block": num_slices_per_block
+    })
+    return new_kv_cache
+
+
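`xb.call_jax` traces a JAX-traceable callable and embeds the lowered computation in the XLA graph torch_xla is building: tensors travel in the positional tuple, compile-time constants in the kwargs dict. The deferred import and `@requires_jax` keep JAX an optional dependency. The same call pattern with a toy function, purely for illustration:

    def toy_scale(x, *, factor):
        return x * factor  # traced by JAX, then lowered into the XLA graph

    # y = xb.call_jax(toy_scale, (x,), {"factor": 2.0})
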
+XLA_LIB.define(
+    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
+    "int page_size, int num_slices_per_block) -> Tensor", )
+
+
+@impl(XLA_LIB, "kv_cache_update_op", "XLA")
+def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                           kv_cache: torch.Tensor, page_size: int,
+                           num_slices_per_block: int) -> torch.Tensor:
+    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
+                                           page_size, num_slices_per_block)
+    return new_kv_cache
+
+
+@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
+def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                               kv_cache: torch.Tensor, page_size: int,
+                               num_slices_per_block: int) -> torch.Tensor:
+    return kv_cache
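
The `CompositeExplicitAutograd` registration is the fallback taken when the op is traced on a non-XLA backend (e.g. during shape propagation); returning `kv_cache` unchanged is safe there because the op's output matches the cache's shape and dtype. Once registered, the op is called like any other torch op; a sketch with illustrative shapes:

    # kv:       [num_tokens, num_combined_kv_heads, head_size]
    # kv_cache: [num_blocks * page_size, num_combined_kv_heads, head_size]
    new_kv_cache = torch.ops.xla.kv_cache_update_op(
        kv, slot_mapping, kv_cache, page_size,
        num_slices_per_kv_cache_update_block)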