Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions requirements-tpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ ray[default]
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

torch==2.7.0.dev20250226+cpu
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
280 changes: 62 additions & 218 deletions vllm/v1/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
# Required to register custom ops.
import torch_xla.experimental.custom_kernel # noqa: F401

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState

NUM_QUERIES_PER_BLOCK = 16
NUM_KV_PAGES_PER_BLOCK = 128


class PallasAttentionBackend(AttentionBackend):

Expand Down Expand Up @@ -47,47 +50,23 @@ def swap_blocks(
) -> None:
raise RuntimeError("swap_blocks is not used for the TPU backend.")

@torch.compile(backend="openxla")
@staticmethod
def copy_blocks(
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
) -> None:
src_indices, dst_indices = src_to_dists
for k_cache, v_cache in kv_caches:
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
k_cache[:, dst_indices] = k_cache[:, src_indices]
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices]


@dataclass
class PallasMetadata(AttentionMetadata):

# Currently, input sequences can only contain all prefills
# or all decoding.
block_tables: Optional[torch.Tensor] = None
context_lens: Optional[torch.Tensor] = None
effective_query_lens: Optional[torch.Tensor] = None

@property
def prefill_metadata(self) -> Optional["PallasMetadata"]:
if self.num_prefills == 0:
return None

assert self.num_decode_tokens == 0
return self

@property
def decode_metadata(self) -> Optional["PallasMetadata"]:
if self.num_decode_tokens == 0:
return None

assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.block_tables is not None
assert self.context_lens is not None
return self
class PallasMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|

# Used in the PallasAttentionBackendImpl
slot_mapping: torch.Tensor
block_tables: torch.Tensor
context_lens: torch.Tensor
query_start_loc: torch.Tensor
num_seqs: int


class PallasAttentionBackendImpl(AttentionImpl):
Expand All @@ -105,10 +84,13 @@ def __init__(
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError("Paged attention Pallas kernel does "
"not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.num_kv_heads = num_kv_heads

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
Expand All @@ -126,25 +108,6 @@ def __init__(
raise NotImplementedError(
"Attention logits soft-capping is not supported.")

if torch_xla.tpu.version() < 4:
raise NotImplementedError("TPU version must be 4 or higher.")

self.megacore_mode = None
tpu_env = torch_xla.tpu.get_tpu_env()
tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
or tpu_env.get("TYPE", None)
or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
assert tpu_type is not None
tpu_type = tpu_type.lower()

if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"
else:
# NOTE(woosuk): If the batch size is not a multiple of 2, the
# megacore mode will be None.
self.megacore_mode = "batch"

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
Expand All @@ -164,135 +127,47 @@ def forward(
"""Forward pass with Pallas attention.

Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
with shape [0] for profiling run.
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
[num_kv_heads, num_blocks, block_size, head_size])
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""

if attn_metadata is None:
# For determine_available_memory case.
if kv_cache[0].numel() == 0:
if output is None:
output = torch.ones_like(query)
return output

assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(num_tokens, self.num_kv_heads, self.head_size)
value = value.view(num_tokens, self.num_kv_heads, self.head_size)

key_cache, value_cache = kv_cache
if kv_cache[0].numel() > 0:
slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

query = query * self.scale
if attn_metadata.num_prefills > 0:
if attn_metadata.block_tables is None:
# Prefill without paged KV cache.
assert seq_len % 16 == 0, (
"Pallas FlashAttention kernel requires seq_len to be a "
f"multiple of 16 but got {seq_len}")

# Handle GQA/MQA.
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv,
dim=-2)
key = key.view(batch_size, seq_len, self.num_heads,
self.head_size)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=-2)
value = value.view(batch_size, seq_len, self.num_heads,
self.head_size)
# FlashAttention kernel requires the input shape to be
# [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output = torch.ops.xla.flash_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
True,
)
output = output.permute(0, 2, 1, 3)
else:
# Prefill with paged KV cache.
# TODO(woosuk): Tune the below knobs.
num_kv_pages_per_compute_block = 16
num_queries_per_compute_block = 16
assert seq_len % num_queries_per_compute_block == 0
output = torch.ops.xla.multi_queries_paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.effective_query_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=True,
)
else:
# Decoding run.
assert kv_cache[0].numel() > 0
query = query.squeeze(dim=1)
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.

assert attn_metadata.block_tables is not None
assert attn_metadata.context_lens is not None
# NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
# block table in SMEM. Therefore, if the block table is too large,
# the kernel compilation will fail. To avoid this, we split the
# batch dimension into smaller chunks and run the kernel multiple
# times.
MAX_SMEM_USAGE = 512 * 1024
size_per_seq = 4 * attn_metadata.block_tables.shape[1]
max_num_seq = MAX_SMEM_USAGE // size_per_seq

if batch_size <= max_num_seq:
output = paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
pages_per_compute_block,
self.megacore_mode,
)
else:
chunk_size = max_num_seq
# Make sure the chunk size is a multiple of 2.
chunk_size = chunk_size // 2 * 2
num_chunks = (batch_size + chunk_size - 1) // chunk_size

output = torch.empty_like(query)
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * chunk_size
chunk_end = chunk_start + chunk_size
# NOTE(woosuk): We skip this line because it causes Dynamo
# compilation error. Instead, we rely on the slice operation
# to handle the out-of-bound case.
# chunk_end = min(chunk_end, batch_size)
chunk_output = paged_attention(
query[chunk_start:chunk_end],
key_cache,
value_cache,
attn_metadata.context_lens[chunk_start:chunk_end],
attn_metadata.block_tables[chunk_start:chunk_end],
pages_per_compute_block,
self.megacore_mode,
)
output[chunk_start:chunk_end] = chunk_output
output = torch.ops.xla.ragged_paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.query_start_loc,
attn_metadata.num_seqs,
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
use_kernel=False,
)

# Reshape the output tensor.
return output.reshape(batch_size, seq_len, hidden_size)
return output.reshape(num_tokens, hidden_size)


def write_to_kv_cache(
Expand All @@ -302,52 +177,21 @@ def write_to_kv_cache(
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
""" Write the key and values to the KV cache.

Args:
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
k_cache = [num_kv_heads, num_blocks, block_size, head_size]
v_cache = [num_kv_heads, num_blocks, block_size, head_size]

"""
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

key = key.flatten(0, 2)
value = value.flatten(0, 2)
key = key.flatten(0, 1)
value = value.flatten(0, 1)
key_cache = key_cache.flatten(0, 2)
value_cache = value_cache.flatten(0, 2)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)


def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
pages_per_compute_block: int,
megacore_mode: Optional[str],
) -> torch.Tensor:
batch_size = query.shape[0]
if megacore_mode == "batch" and batch_size % 2 != 0:
megacore_mode = None
else:
megacore_mode = megacore_mode

# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if megacore_mode is not None:
output = torch.ops.xla.paged_attention(
query,
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
megacore_mode=megacore_mode,
)
else:
output = torch.ops.xla.paged_attention(
query,
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
)
return output
2 changes: 1 addition & 1 deletion vllm/v1/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,4 @@ class ModelRunnerOutput:
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
prompt_logprobs_dict: Dict[str, LogprobsTensors]
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]]
4 changes: 2 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,12 +1071,12 @@ def _get_prompt_logprobs_dict(
self,
hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput",
) -> Dict[str, LogprobsTensors]:
) -> Dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}

prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {}

# Since prompt logprobs are a rare feature, prioritize simple,
# maintainable loop over optimal performance.
Expand Down
Loading