 from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
-import torch_xla.experimental.custom_kernel  # Required to register custom ops.
+# Required to register custom ops.
+import torch_xla.experimental.custom_kernel  # noqa: F401
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
+NUM_QUERIES_PER_BLOCK = 16
+NUM_KV_PAGES_PER_BLOCK = 128
+
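# Illustrative note (not part of the change): these module-level constants are
# the tile sizes handed to the ragged paged-attention kernel in forward() below.
# Each kernel block covers up to NUM_QUERIES_PER_BLOCK query tokens and
# NUM_KV_PAGES_PER_BLOCK KV-cache pages, replacing the per-call
# num_queries_per_compute_block / num_kv_pages_per_compute_block knobs of the
# removed code.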
 
 class PallasAttentionBackend(AttentionBackend):
 
@@ -47,47 +50,23 @@ def swap_blocks(
     ) -> None:
         raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
-    @torch.compile(backend="openxla")
-    @staticmethod
-    def copy_blocks(
-        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
-    ) -> None:
-        src_indices, dst_indices = src_to_dists
-        for k_cache, v_cache in kv_caches:
-            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
-            k_cache[:, dst_indices] = k_cache[:, src_indices]
-            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
-            v_cache[:, dst_indices] = v_cache[:, src_indices]
-
 
 @dataclass
-class PallasMetadata(AttentionMetadata):
-
-    # Currently, input sequences can only contain all prefills
-    # or all decoding.
-    block_tables: Optional[torch.Tensor] = None
-    context_lens: Optional[torch.Tensor] = None
-    effective_query_lens: Optional[torch.Tensor] = None
-
-    @property
-    def prefill_metadata(self) -> Optional["PallasMetadata"]:
-        if self.num_prefills == 0:
-            return None
-
-        assert self.num_decode_tokens == 0
-        return self
-
-    @property
-    def decode_metadata(self) -> Optional["PallasMetadata"]:
-        if self.num_decode_tokens == 0:
-            return None
-
-        assert self.num_prefills == 0
-        assert self.num_prefill_tokens == 0
-        assert self.block_tables is not None
-        assert self.context_lens is not None
-        return self
+class PallasMetadata:
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seq_len ---------------------|
+    #                                   |-- query_len ---|
+
+    # Used in the PallasAttentionBackendImpl
+    slot_mapping: torch.Tensor
+    block_tables: torch.Tensor
+    context_lens: torch.Tensor
+    query_start_loc: torch.Tensor
+    num_seqs: int
 
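# Illustrative sketch (not part of the change): how these fields fit together
# for a ragged batch. Suppose num_seqs = 3 sequences are scheduled in one step
# with query lengths [3, 1, 5]; the flattened query/key/value tensors then hold
# 3 + 1 + 5 = 9 tokens and
#     query_start_loc = [0, 3, 4, 9]   # cumulative query lengths, num_seqs + 1 entries
# so query_len_i = query_start_loc[i + 1] - query_start_loc[i]. context_lens
# and block_tables are indexed per sequence (the KV length visible to the
# kernel and the list of KV-cache page indices, respectively), while
# slot_mapping gives the flattened KV-cache destination rows for the key/value
# entries written this step.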
 
 class PallasAttentionBackendImpl(AttentionImpl):
 
@@ -105,10 +84,13 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
     ) -> None:
+        if blocksparse_params is not None:
+            raise ValueError("Paged attention Pallas kernel does "
+                             "not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.num_kv_heads = num_kv_heads
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
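# Illustrative note (not part of the change): with grouped-query attention this
# ratio is the number of query heads that share one KV head, e.g.
# num_heads = 32 and num_kv_heads = 8 give num_queries_per_kv = 4; plain
# multi-head attention gives 1.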
@@ -126,25 +108,6 @@ def __init__(
             raise NotImplementedError(
                 "Attention logits soft-capping is not supported.")
 
-        if torch_xla.tpu.version() < 4:
-            raise NotImplementedError("TPU version must be 4 or higher.")
-
-        self.megacore_mode = None
-        tpu_env = torch_xla.tpu.get_tpu_env()
-        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
-                    or tpu_env.get("TYPE", None)
-                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
-        assert tpu_type is not None
-        tpu_type = tpu_type.lower()
-
-        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
-            if self.num_kv_heads % 2 == 0:
-                self.megacore_mode = "kv_head"
-            else:
-                # NOTE(woosuk): If the batch size is not a multiple of 2, the
-                # megacore mode will be None.
-                self.megacore_mode = "batch"
-
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
@@ -164,135 +127,47 @@ def forward(
         """Forward pass with Pallas attention.
 
         Args:
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
-            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
-                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
-                with shape [0] for profiling run.
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
+                        [num_kv_heads, num_blocks, block_size, head_size])
            attn_metadata: Metadata for attention.
         Returns:
-            shape = [batch_size, seq_len, num_heads * head_size]
+            shape = [num_tokens, num_heads * head_size]
         """
-
-        if attn_metadata is None:
+        # The KV cache is an empty placeholder during the memory-profiling run
+        # (determine_available_memory), so skip attention entirely.
+        if kv_cache[0].numel() == 0:
             if output is None:
                 output = torch.ones_like(query)
             return output
 
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
-        batch_size, seq_len, hidden_size = query.shape
-        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
-        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
-        value = value.view(batch_size, seq_len, self.num_kv_heads,
-                           self.head_size)
+        num_tokens, hidden_size = query.shape
+        query = query.view(num_tokens, self.num_heads, self.head_size)
+        key = key.view(num_tokens, self.num_kv_heads, self.head_size)
+        value = value.view(num_tokens, self.num_kv_heads, self.head_size)
 
+        key_cache, value_cache = kv_cache
         if kv_cache[0].numel() > 0:
             slot_mapping = attn_metadata.slot_mapping
-            key_cache, value_cache = kv_cache
             write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
 
         query = query * self.scale
-        if attn_metadata.num_prefills > 0:
-            if attn_metadata.block_tables is None:
-                # Prefill without paged KV cache.
-                assert seq_len % 16 == 0, (
-                    "Pallas FlashAttention kernel requires seq_len to be a "
-                    f"multiple of 16 but got {seq_len}")
-
-                # Handle GQA/MQA.
-                if self.num_kv_heads != self.num_heads:
-                    key = key.repeat_interleave(self.num_queries_per_kv,
-                                                dim=-2)
-                    key = key.view(batch_size, seq_len, self.num_heads,
-                                   self.head_size)
-                    value = value.repeat_interleave(self.num_queries_per_kv,
-                                                    dim=-2)
-                    value = value.view(batch_size, seq_len, self.num_heads,
-                                       self.head_size)
-                # FlashAttention kernel requires the input shape to be
-                # [batch_size, num_heads, seq_len, d_model]
-                # while the input is [batch_size, seq_len, num_heads, d_model].
-                # Permute the input to match the required format.
-                output = torch.ops.xla.flash_attention(
-                    query.permute(0, 2, 1, 3),
-                    key.permute(0, 2, 1, 3),
-                    value.permute(0, 2, 1, 3),
-                    True,
-                )
-                output = output.permute(0, 2, 1, 3)
-            else:
-                # Prefill with paged KV cache.
-                # TODO(woosuk): Tune the below knobs.
-                num_kv_pages_per_compute_block = 16
-                num_queries_per_compute_block = 16
-                assert seq_len % num_queries_per_compute_block == 0
-                output = torch.ops.xla.multi_queries_paged_attention(
-                    query,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.context_lens,
-                    attn_metadata.block_tables,
-                    attn_metadata.effective_query_lens,
-                    num_kv_pages_per_compute_block,
-                    num_queries_per_compute_block,
-                    use_kernel=True,
-                )
-        else:
-            # Decoding run.
-            assert kv_cache[0].numel() > 0
-            query = query.squeeze(dim=1)
-            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.
-
-            assert attn_metadata.block_tables is not None
-            assert attn_metadata.context_lens is not None
-            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
-            # block table in SMEM. Therefore, if the block table is too large,
-            # the kernel compilation will fail. To avoid this, we split the
-            # batch dimension into smaller chunks and run the kernel multiple
-            # times.
-            MAX_SMEM_USAGE = 512 * 1024
-            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
-            max_num_seq = MAX_SMEM_USAGE // size_per_seq
-
-            if batch_size <= max_num_seq:
-                output = paged_attention(
-                    query,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.context_lens,
-                    attn_metadata.block_tables,
-                    pages_per_compute_block,
-                    self.megacore_mode,
-                )
-            else:
-                chunk_size = max_num_seq
-                # Make sure the chunk size is a multiple of 2.
-                chunk_size = chunk_size // 2 * 2
-                num_chunks = (batch_size + chunk_size - 1) // chunk_size
-
-                output = torch.empty_like(query)
-                for chunk_idx in range(num_chunks):
-                    chunk_start = chunk_idx * chunk_size
-                    chunk_end = chunk_start + chunk_size
-                    # NOTE(woosuk): We skip this line because it causes Dynamo
-                    # compilation error. Instead, we rely on the slice operation
-                    # to handle the out-of-bound case.
-                    # chunk_end = min(chunk_end, batch_size)
-                    chunk_output = paged_attention(
-                        query[chunk_start:chunk_end],
-                        key_cache,
-                        value_cache,
-                        attn_metadata.context_lens[chunk_start:chunk_end],
-                        attn_metadata.block_tables[chunk_start:chunk_end],
-                        pages_per_compute_block,
-                        self.megacore_mode,
-                    )
-                    output[chunk_start:chunk_end] = chunk_output
+        output = torch.ops.xla.ragged_paged_attention(
+            query,
+            key_cache,
+            value_cache,
+            attn_metadata.context_lens,
+            attn_metadata.block_tables,
+            attn_metadata.query_start_loc,
+            attn_metadata.num_seqs,
+            num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
+            num_queries_per_block=NUM_QUERIES_PER_BLOCK,
+            use_kernel=False,
+        )
 
-        # Reshape the output tensor.
-        return output.reshape(batch_size, seq_len, hidden_size)
+        return output.reshape(num_tokens, hidden_size)
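# Illustrative sketch (not part of the change): a naive pure-PyTorch reference
# for what the ragged paged-attention call above computes, using the shapes
# documented in the docstring. Assumptions: block_tables is a padded 2-D int
# tensor of per-sequence page indices, context_lens holds each sequence's total
# KV length visible to the kernel, query has already been multiplied by the
# softmax scale (as done above), and the causal mask the real kernel applies
# over each sequence's trailing query tokens is omitted for brevity. The
# function name below is made up for this sketch.
import torch


def ragged_paged_attention_reference(query, key_cache, value_cache,
                                     context_lens, block_tables,
                                     query_start_loc, num_seqs):
    num_kv_heads, num_blocks, block_size, head_size = key_cache.shape
    num_queries_per_kv = query.shape[1] // num_kv_heads
    output = torch.empty_like(query)
    for i in range(num_seqs):
        q_start = int(query_start_loc[i])
        q_end = int(query_start_loc[i + 1])
        kv_len = int(context_lens[i])
        # Gather this sequence's KV-cache pages and trim to its true KV length.
        pages = block_tables[i, :(kv_len + block_size - 1) // block_size]
        k = key_cache[:, pages].reshape(num_kv_heads, -1, head_size)[:, :kv_len]
        v = value_cache[:, pages].reshape(num_kv_heads, -1, head_size)[:, :kv_len]
        # Expand KV heads so each query head has a matching KV head (GQA/MQA).
        k = k.repeat_interleave(num_queries_per_kv, dim=0)
        v = v.repeat_interleave(num_queries_per_kv, dim=0)
        q = query[q_start:q_end].transpose(0, 1)  # [num_heads, q_len, head_size]
        attn = torch.softmax(q @ k.transpose(-1, -2), dim=-1)
        output[q_start:q_end] = (attn @ v).transpose(0, 1)
    return output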
 
 
 def write_to_kv_cache(
@@ -302,52 +177,21 @@ def write_to_kv_cache(
     value_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
 ) -> None:
+    """Write the key and values to the KV cache.
+
+    Args:
+        key: shape = [num_tokens, num_kv_heads, head_size]
+        value: shape = [num_tokens, num_kv_heads, head_size]
+        k_cache = [num_kv_heads, num_blocks, block_size, head_size]
+        v_cache = [num_kv_heads, num_blocks, block_size, head_size]
+
+    """
     torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
     torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
 
-    key = key.flatten(0, 2)
-    value = value.flatten(0, 2)
+    key = key.flatten(0, 1)
+    value = value.flatten(0, 1)
     key_cache = key_cache.flatten(0, 2)
     value_cache = value_cache.flatten(0, 2)
     key_cache.index_copy_(0, slot_mapping, key)
     value_cache.index_copy_(0, slot_mapping, value)
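# Illustrative note (not part of the change): key_cache.flatten(0, 2) turns
# [num_kv_heads, num_blocks, block_size, head_size] into a flat table with
# num_kv_heads * num_blocks * block_size rows, while key.flatten(0, 1) has
# num_tokens * num_kv_heads rows, so as written slot_mapping must supply one
# flat destination row per (token, kv_head) pair. Marking the caches as buffer
# donors lets XLA alias the existing cache buffers for the in-place
# index_copy_ rather than allocating new ones on every step.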
-
-
-def paged_attention(
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    context_lens: torch.Tensor,
-    block_tables: torch.Tensor,
-    pages_per_compute_block: int,
-    megacore_mode: Optional[str],
-) -> torch.Tensor:
-    batch_size = query.shape[0]
-    if megacore_mode == "batch" and batch_size % 2 != 0:
-        megacore_mode = None
-    else:
-        megacore_mode = megacore_mode
-
-    # NOTE(woosuk): A temporary workaround to avoid the error:
-    # "xla::paged_attention() Expected a value of type 'str' for
-    # argument 'megacore_mode' but instead found type 'NoneType'."
-    if megacore_mode is not None:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-            megacore_mode=megacore_mode,
-        )
-    else:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-        )
-    return output