
Commit cbb9669

vanbasten23 authored and mgoin committed
[V1][TPU] Integrate the new ragged paged attention kernel with vLLM v1 on TPU (vllm-project#13379)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent 81238b0 commit cbb9669

File tree

6 files changed, +353 −905 lines changed


requirements-tpu.txt

Lines changed: 5 additions & 6 deletions
```diff
@@ -17,9 +17,8 @@ ray[default]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+
+torch==2.7.0.dev20250226+cpu
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
```
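The per-interpreter pins above rely on PEP 508 environment markers, so exactly one `torch_xla` wheel is selected per Python version. As a minimal sketch (not part of the commit, and using the third-party `packaging` library rather than anything in the diff), this is how such a marker evaluates at install time:

```python
# Sketch: evaluate the environment marker used on the torch_xla lines above.
# The `packaging` import is an assumption for illustration; pip performs the
# same evaluation internally when resolving requirements-tpu.txt.
from packaging.markers import Marker

marker = Marker('python_version == "3.10"')
print(marker.evaluate())  # True only when running under CPython 3.10
```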

vllm/v1/attention/backends/pallas.py

Lines changed: 62 additions & 218 deletions
```diff
@@ -4,13 +4,16 @@
 from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
-import torch_xla.experimental.custom_kernel  # Required to register custom ops.
+# Required to register custom ops.
+import torch_xla.experimental.custom_kernel  # noqa: F401
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
+NUM_QUERIES_PER_BLOCK = 16
+NUM_KV_PAGES_PER_BLOCK = 128
+
 
 class PallasAttentionBackend(AttentionBackend):
@@ -47,47 +50,23 @@ def swap_blocks(
     ) -> None:
         raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
-    @torch.compile(backend="openxla")
-    @staticmethod
-    def copy_blocks(
-        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
-    ) -> None:
-        src_indices, dst_indices = src_to_dists
-        for k_cache, v_cache in kv_caches:
-            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
-            k_cache[:, dst_indices] = k_cache[:, src_indices]
-            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
-            v_cache[:, dst_indices] = v_cache[:, src_indices]
-
 
 @dataclass
-class PallasMetadata(AttentionMetadata):
-
-    # Currently, input sequences can only contain all prefills
-    # or all decoding.
-    block_tables: Optional[torch.Tensor] = None
-    context_lens: Optional[torch.Tensor] = None
-    effective_query_lens: Optional[torch.Tensor] = None
-
-    @property
-    def prefill_metadata(self) -> Optional["PallasMetadata"]:
-        if self.num_prefills == 0:
-            return None
-
-        assert self.num_decode_tokens == 0
-        return self
-
-    @property
-    def decode_metadata(self) -> Optional["PallasMetadata"]:
-        if self.num_decode_tokens == 0:
-            return None
-
-        assert self.num_prefills == 0
-        assert self.num_prefill_tokens == 0
-        assert self.block_tables is not None
-        assert self.context_lens is not None
-        return self
+class PallasMetadata:
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    #   |---------- N-1 iteration --------|
+    #   |---------------- N iteration ---------------------|
+    #   |- tokenA -|......................|-- newTokens ---|
+    #   |---------- context_len ----------|
+    #   |-------------------- seq_len ---------------------|
+    #                                     |-- query_len ---|
+
+    # Used in the PallasAttentionBackendImpl
+    slot_mapping: torch.Tensor
+    block_tables: torch.Tensor
+    context_lens: torch.Tensor
+    query_start_loc: torch.Tensor
+    num_seqs: int
 
 
 class PallasAttentionBackendImpl(AttentionImpl):
@@ -105,10 +84,13 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
     ) -> None:
+        if blocksparse_params is not None:
+            raise ValueError("Paged attention Pallas kernel does "
+                             "not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.num_kv_heads = num_kv_heads
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -126,25 +108,6 @@ def __init__(
             raise NotImplementedError(
                 "Attention logits soft-capping is not supported.")
 
-        if torch_xla.tpu.version() < 4:
-            raise NotImplementedError("TPU version must be 4 or higher.")
-
-        self.megacore_mode = None
-        tpu_env = torch_xla.tpu.get_tpu_env()
-        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
-                    or tpu_env.get("TYPE", None)
-                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
-        assert tpu_type is not None
-        tpu_type = tpu_type.lower()
-
-        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
-            if self.num_kv_heads % 2 == 0:
-                self.megacore_mode = "kv_head"
-            else:
-                # NOTE(woosuk): If the batch size is not a multiple of 2, the
-                # megacore mode will be None.
-                self.megacore_mode = "batch"
-
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
@@ -164,135 +127,47 @@ def forward(
         """Forward pass with Pallas attention.
 
         Args:
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
-            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
-                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
-                with shape [0] for profiling run.
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
+                        [num_kv_heads, num_blocks, block_size, head_size])
             attn_metadata: Metadata for attention.
         Returns:
-            shape = [batch_size, seq_len, num_heads * head_size]
+            shape = [num_tokens, num_heads * head_size]
         """
-
-        if attn_metadata is None:
+        # For determine_available_memory case.
+        if kv_cache[0].numel() == 0:
             if output is None:
                 output = torch.ones_like(query)
             return output
 
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
-        batch_size, seq_len, hidden_size = query.shape
-        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
-        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
-        value = value.view(batch_size, seq_len, self.num_kv_heads,
-                           self.head_size)
+        num_tokens, hidden_size = query.shape
+        query = query.view(num_tokens, self.num_heads, self.head_size)
+        key = key.view(num_tokens, self.num_kv_heads, self.head_size)
+        value = value.view(num_tokens, self.num_kv_heads, self.head_size)
 
+        key_cache, value_cache = kv_cache
         if kv_cache[0].numel() > 0:
             slot_mapping = attn_metadata.slot_mapping
-            key_cache, value_cache = kv_cache
             write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
 
         query = query * self.scale
-        if attn_metadata.num_prefills > 0:
-            if attn_metadata.block_tables is None:
-                # Prefill without paged KV cache.
-                assert seq_len % 16 == 0, (
-                    "Pallas FlashAttention kernel requires seq_len to be a "
-                    f"multiple of 16 but got {seq_len}")
-
-                # Handle GQA/MQA.
-                if self.num_kv_heads != self.num_heads:
-                    key = key.repeat_interleave(self.num_queries_per_kv,
-                                                dim=-2)
-                    key = key.view(batch_size, seq_len, self.num_heads,
-                                   self.head_size)
-                    value = value.repeat_interleave(self.num_queries_per_kv,
-                                                    dim=-2)
-                    value = value.view(batch_size, seq_len, self.num_heads,
-                                       self.head_size)
-                # FlashAttention kernel requires the input shape to be
-                # [batch_size, num_heads, seq_len, d_model]
-                # while the input is [batch_size, seq_len, num_heads, d_model].
-                # Permute the input to match the required format.
-                output = torch.ops.xla.flash_attention(
-                    query.permute(0, 2, 1, 3),
-                    key.permute(0, 2, 1, 3),
-                    value.permute(0, 2, 1, 3),
-                    True,
-                )
-                output = output.permute(0, 2, 1, 3)
-            else:
-                # Prefill with paged KV cache.
-                # TODO(woosuk): Tune the below knobs.
-                num_kv_pages_per_compute_block = 16
-                num_queries_per_compute_block = 16
-                assert seq_len % num_queries_per_compute_block == 0
-                output = torch.ops.xla.multi_queries_paged_attention(
-                    query,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.context_lens,
-                    attn_metadata.block_tables,
-                    attn_metadata.effective_query_lens,
-                    num_kv_pages_per_compute_block,
-                    num_queries_per_compute_block,
-                    use_kernel=True,
-                )
-        else:
-            # Decoding run.
-            assert kv_cache[0].numel() > 0
-            query = query.squeeze(dim=1)
-            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.
-
-            assert attn_metadata.block_tables is not None
-            assert attn_metadata.context_lens is not None
-            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
-            # block table in SMEM. Therefore, if the block table is too large,
-            # the kernel compilation will fail. To avoid this, we split the
-            # batch dimension into smaller chunks and run the kernel multiple
-            # times.
-            MAX_SMEM_USAGE = 512 * 1024
-            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
-            max_num_seq = MAX_SMEM_USAGE // size_per_seq
-
-            if batch_size <= max_num_seq:
-                output = paged_attention(
-                    query,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.context_lens,
-                    attn_metadata.block_tables,
-                    pages_per_compute_block,
-                    self.megacore_mode,
-                )
-            else:
-                chunk_size = max_num_seq
-                # Make sure the chunk size is a multiple of 2.
-                chunk_size = chunk_size // 2 * 2
-                num_chunks = (batch_size + chunk_size - 1) // chunk_size
-
-                output = torch.empty_like(query)
-                for chunk_idx in range(num_chunks):
-                    chunk_start = chunk_idx * chunk_size
-                    chunk_end = chunk_start + chunk_size
-                    # NOTE(woosuk): We skip this line because it causes Dynamo
-                    # compilation error. Instead, we rely on the slice operation
-                    # to handle the out-of-bound case.
-                    # chunk_end = min(chunk_end, batch_size)
-                    chunk_output = paged_attention(
-                        query[chunk_start:chunk_end],
-                        key_cache,
-                        value_cache,
-                        attn_metadata.context_lens[chunk_start:chunk_end],
-                        attn_metadata.block_tables[chunk_start:chunk_end],
-                        pages_per_compute_block,
-                        self.megacore_mode,
-                    )
-                    output[chunk_start:chunk_end] = chunk_output
+        output = torch.ops.xla.ragged_paged_attention(
+            query,
+            key_cache,
+            value_cache,
+            attn_metadata.context_lens,
+            attn_metadata.block_tables,
+            attn_metadata.query_start_loc,
+            attn_metadata.num_seqs,
+            num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
+            num_queries_per_block=NUM_QUERIES_PER_BLOCK,
+            use_kernel=False,
+        )
 
-        # Reshape the output tensor.
-        return output.reshape(batch_size, seq_len, hidden_size)
+        return output.reshape(num_tokens, hidden_size)
 
 
 def write_to_kv_cache(
@@ -302,52 +177,21 @@ def write_to_kv_cache(
     value_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
 ) -> None:
+    """ Write the key and values to the KV cache.
+
+    Args:
+        key: shape = [num_tokens, num_kv_heads, head_size]
+        value: shape = [num_tokens, num_kv_heads, head_size]
+        k_cache = [num_kv_heads, num_blocks, block_size, head_size]
+        v_cache = [num_kv_heads, num_blocks, block_size, head_size]
+
+    """
     torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
 
-    key = key.flatten(0, 2)
-    value = value.flatten(0, 2)
+    key = key.flatten(0, 1)
+    value = value.flatten(0, 1)
     key_cache = key_cache.flatten(0, 2)
     value_cache = value_cache.flatten(0, 2)
     key_cache.index_copy_(0, slot_mapping, key)
     value_cache.index_copy_(0, slot_mapping, value)
-
-
-def paged_attention(
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    context_lens: torch.Tensor,
-    block_tables: torch.Tensor,
-    pages_per_compute_block: int,
-    megacore_mode: Optional[str],
-) -> torch.Tensor:
-    batch_size = query.shape[0]
-    if megacore_mode == "batch" and batch_size % 2 != 0:
-        megacore_mode = None
-    else:
-        megacore_mode = megacore_mode
-
-    # NOTE(woosuk): A temporary workaround to avoid the error:
-    # "xla::paged_attention() Expected a value of type 'str' for
-    # argument 'megacore_mode' but instead found type 'NoneType'."
-    if megacore_mode is not None:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-            megacore_mode=megacore_mode,
-        )
-    else:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-        )
-    return output
```
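Unlike the old `PallasMetadata`, the new dataclass no longer splits the batch into prefill-only or decode-only metadata: the ragged kernel consumes a single flattened token dimension and uses `query_start_loc` to recover per-sequence boundaries. The following is a minimal sketch (not part of the commit) of how such metadata could be assembled for a mixed batch; it assumes an environment where vLLM's TPU backend imports (i.e. `torch_xla` is installed), and all tensor values are illustrative placeholders rather than what the TPU model runner actually produces:

```python
# Sketch: one 5-token prefill chunk plus two single-token decodes, flattened
# into a ragged batch of 7 tokens. Field names mirror the dataclass above.
import torch

from vllm.v1.attention.backends.pallas import PallasMetadata

query_lens = torch.tensor([5, 1, 1], dtype=torch.int32)      # tokens scheduled per sequence
context_lens = torch.tensor([5, 17, 33], dtype=torch.int32)  # total tokens seen per sequence
num_tokens = int(query_lens.sum())                           # 7 tokens in the ragged batch

metadata = PallasMetadata(
    # one KV-cache slot per scheduled token (placeholder values)
    slot_mapping=torch.arange(num_tokens, dtype=torch.int64),
    # [num_seqs, max_blocks_per_seq] physical page indices (placeholder values)
    block_tables=torch.zeros((3, 8), dtype=torch.int32),
    context_lens=context_lens,
    # cumulative query offsets per sequence: [0, 5, 6, 7]
    query_start_loc=torch.cat(
        [torch.zeros(1, dtype=torch.int32),
         torch.cumsum(query_lens, dim=0, dtype=torch.int32)]),
    num_seqs=3,
)
```

With this single metadata object, `PallasAttentionBackendImpl.forward` issues one `torch.ops.xla.ragged_paged_attention` call instead of branching between the flash-attention, multi-query paged, and chunked decode paths that the diff removes.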

vllm/v1/outputs.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -79,4 +79,4 @@ class ModelRunnerOutput:
     # [prompt_len, num_prompt_logprobs]
     # [prompt_len, num_prompt_logprobs]
     # [prompt_len]
-    prompt_logprobs_dict: Dict[str, LogprobsTensors]
+    prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]]
```

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1071,12 +1071,12 @@ def _get_prompt_logprobs_dict(
         self,
         hidden_states: torch.Tensor,
         scheduler_output: "SchedulerOutput",
-    ) -> Dict[str, LogprobsTensors]:
+    ) -> Dict[str, Optional[LogprobsTensors]]:
         num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
         if not num_prompt_logprobs_dict:
             return {}
 
-        prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}
+        prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {}
 
         # Since prompt logprobs are a rare feature, prioritize simple,
         # maintainable loop over optimal performance.
```
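The two changes above widen `prompt_logprobs_dict` so its values may now be `None`. A minimal sketch of the consumer-side implication (not part of the commit; `collect_ready` is a hypothetical helper introduced only for illustration):

```python
# Sketch: with Dict[str, Optional[LogprobsTensors]], downstream code must
# tolerate None entries, e.g. requests recorded without prompt logprobs.
from typing import Dict, Optional


def collect_ready(prompt_logprobs_dict: Dict[str, Optional["LogprobsTensors"]]):
    """Yield only the requests whose prompt logprobs are actually available."""
    for req_id, logprobs in prompt_logprobs_dict.items():
        if logprobs is None:
            continue  # placeholder entry; nothing to report for this request
        yield req_id, logprobs
```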
