From 2ede1a2b1da1eb7274e0d7bf1df8e499dd585983 Mon Sep 17 00:00:00 2001 From: Agata Dobrzyniewicz Date: Mon, 5 May 2025 12:07:22 +0300 Subject: [PATCH 1/2] Added HPU support for Automatic Prefix Caching Signed-off-by: Agata Dobrzyniewicz --- vllm/attention/backends/hpu_attn.py | 32 ++++--- vllm/attention/ops/hpu_paged_attn.py | 36 ++----- vllm/worker/hpu_model_runner.py | 134 +++++++++++++++++++++++---- 3 files changed, 146 insertions(+), 56 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 55a63a81677f..d701c59a234f 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -57,16 +57,16 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dsts: torch.Tensor, ) -> None: - HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dsts: torch.Tensor, ) -> None: - HPUPagedAttention.copy_blocks(kv_caches, src_to_dists) + HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts) @dataclass @@ -77,6 +77,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): is_prompt: bool attn_bias: Optional[torch.Tensor] seq_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] class HPUAttentionImpl(AttentionImpl, torch.nn.Module): @@ -198,8 +199,7 @@ def forward( key_cache = None value_cache = None if attn_metadata.is_prompt and self.attn_type \ - is not AttentionType.ENCODER_ONLY \ - and attn_metadata.block_list is None: + is not AttentionType.ENCODER_ONLY: key = key.unflatten(0, (block_indices.size(0), -1)) value = value.unflatten(0, (block_indices.size(0), -1)) if kv_cache is not None and isinstance(kv_cache, tuple): @@ -229,6 +229,9 @@ def forward( attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) attn_bias.add_(position_bias) + block_list = attn_metadata.block_list if attn_metadata \ + and attn_metadata.block_list is not None else None + out = ops.prompt_attention( impl=self.prefill_impl, query=query.view(query_shape), @@ -237,23 +240,25 @@ def forward( is_causal=True, attn_bias=attn_bias, valid_seq_lengths=attn_metadata.seq_lens_tensor, - **self.common_attention_args()) + **self.common_attention_args(block_list, key_cache, + value_cache)) output = out.reshape(batch_size, seq_len, hidden_size) else: # Decoding run. output = HPUPagedAttention.forward_decode( query=query, - key_cache=key_cache, - value_cache=value_cache, - block_list=attn_metadata.block_list, block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, block_groups=attn_metadata.block_groups, - **self.common_attention_args()) + **self.common_attention_args(attn_metadata.block_list, + key_cache, value_cache)) # Reshape the output tensor. 
return output.view(batch_size, seq_len, hidden_size) - def common_attention_args(self): + def common_attention_args(self, + block_list=None, + key_cache=None, + value_cache=None): fsdpa_op = self.fused_scaled_dot_product_attention.apply \ if self.fused_scaled_dot_product_attention is not None else None return { @@ -266,6 +271,9 @@ def common_attention_args(self): 'keys_fetch_func': self.k_cache.fetch_from_cache, 'values_fetch_func': self.v_cache.fetch_from_cache, 'softmax_op': self.softmax, + 'block_list': block_list, + 'key_cache': key_cache, + 'value_cache': value_cache, } diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index 1dedd2ffc5fa..9b2bcd5078ac 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -5,7 +5,7 @@ ############################################################################### from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from vllm_hpu_extension import cache_ops, ops @@ -63,43 +63,25 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, def forward_decode(**kwargs) -> torch.Tensor: return ops.flat_pa(**kwargs) - @staticmethod - def forward_prefix( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - subquery_start_loc: torch.Tensor, - seq_lens_tensor: torch.Tensor, - context_lens: torch.Tensor, - max_query_len: int, - alibi_slopes: Optional[torch.Tensor], - sliding_window: Optional[int], - ) -> torch.Tensor: - raise NotImplementedError( - "forward_prefix is not implemented for HPUPagedAttention") - @staticmethod def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_kv_cache: Tuple[torch.Tensor, torch.Tensor], + dst_kv_cache: Tuple[torch.Tensor, torch.Tensor], + src_to_dsts: torch.Tensor, ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts) src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts) @staticmethod def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dsts: torch.Tensor, ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + cache_ops.copy_blocks(key_caches, value_caches, src_to_dist) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index e25864349e28..a343e2fedb23 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -14,7 +14,7 @@ import os import time from array import array -from enum import IntEnum +from enum import Enum, IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union) @@ -75,6 +75,12 @@ DUMMY_TOKEN_ID = -1 +class PhaseType(Enum): + PREFILL = 'prefill' + PREFIX_PREFILL = 'prefix_prefill' + DECODE = 'decode' + + def subtuple(obj: object, typename: str, to_copy: List[str], @@ -213,20 +219,40 @@ def _compile_region(self, model, name, module): def 
_set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): - prefill_metadata = attn_metadata - if prefill_metadata is None or self.prefill_use_fusedsdpa: + if (attn_metadata is None + or (self.prefill_use_fusedsdpa \ + and attn_metadata.block_list is None) + or not attn_metadata.is_prompt): return attn_metadata + prefill_metadata = attn_metadata + seq_lens_t = prefill_metadata.seq_lens_tensor + context_lens_t = prefill_metadata.context_lens_tensor + query_lens_t = seq_lens_t - context_lens_t + + block_list = attn_metadata.block_list + max_context_len = (block_list.size(-1) // + batch_size if block_list is not None else 0) + max_context_len = max_context_len * self.block_size + past_mask = torch.arange(0, + max_context_len, + dtype=torch.int32, + device=device) + past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge( + context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand( + batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1)) + len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32).view(1, seq_len).ge( - seq_lens_t.unsqueeze(-1)).view( + query_lens_t.unsqueeze(-1)).view( batch_size, 1, 1, seq_len)) causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool), diagonal=1) mask = causal_mask.logical_or(len_mask) + mask = torch.concat((past_mask, mask), dim=-1) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) @@ -517,6 +543,11 @@ def __init__( False, self.max_model_len) self.graphed_buckets: Set[Any] = set() self._set_gc_threshold() + if self.vllm_config.cache_config.enable_prefix_caching: + os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False") + assert os.environ.get( + "VLLM_CONTIGUOUS_PA", + "").lower() != "true", "Contiguous PA doesn't support APC" self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH # For multi-step scheduling @@ -702,6 +733,10 @@ def _prepare_prompt( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size + if context_len == seq_len \ + and self.vllm_config.cache_config.enable_prefix_caching: + # Fully cached prompt - compute only last token + context_len = context_len - 1 prompt_tokens = prompt_tokens[context_len:] prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: @@ -779,12 +814,33 @@ def _prepare_prompt( if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * (max_prompt_len - context_len) + lora_index_mapping += [lora_id] * max_prompt_len lora_prompt_mapping.extend( [lora_id] * - (max_prompt_len - context_len + (max_prompt_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + if any(context_lens): + assert not self.scheduler_config.chunked_prefill_enabled + # prefix caching + + max_num_block = max(len(bt) for bt in prefix_block_tables) + prefix_block_list = list( + itertools.chain.from_iterable( + bt if len(bt) == max_num_block else bt + + ([_PAD_BLOCK_ID] * (max_num_block - len(bt))) + for bt in prefix_block_tables)) + + pad_len = len(prefix_block_list) + prefix_block_list = pad_list(prefix_block_list, pad_len, + _PAD_BLOCK_ID) + + prefix_block_list_tensor = torch.tensor(prefix_block_list, + dtype=torch.long, + device=self.device) + else: + prefix_block_list_tensor = None + input_tokens = make_tensor_with_pad(input_tokens, max_len=max_prompt_len, pad=0, @@ 
-807,11 +863,15 @@ def _prepare_prompt( dtype=torch.long, device=self.device) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.long, + device=self.device) + block_indices, block_offsets = precompute_indices_and_offsets( self.block_size, slot_mapping, True) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - block_list=None, + block_list=prefix_block_list_tensor, block_mapping=None, block_usage=None, block_indices=block_indices, @@ -819,6 +879,7 @@ def _prepare_prompt( block_groups=None, attn_bias=None, seq_lens_tensor=seq_lens_tensor, + context_lens_tensor=context_lens_tensor, num_prefills=real_num_seqs, num_prefill_tokens=sum_query_len, num_decode_tokens=0, @@ -987,6 +1048,7 @@ def _prepare_decode( block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, + context_lens_tensor=None, num_prefills=0, num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, @@ -1091,7 +1153,7 @@ def prepare_input_tensors( # FIXME: We need to adjust selected_token_indices to accommodate # for padding max_len = input_tokens.size(1) - paddings = [max_len - s for s in seq_lens] + paddings = [max_len - q for q in query_lens] paddings = [0] + paddings[:-1] paddings = list(itertools.accumulate(paddings)) paddings_prompt_logprobs = [] @@ -1187,9 +1249,17 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # input_hash(123) != input_hash(321) # input_hash("abc") != input_hash("cba") attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ - 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', - 'block_usage', 'slot_mapping', 'is_prompt', 'block_indices', - 'block_offsets', 'block_groups' + 'attn_bias', + 'seq_lens_tensor', + 'context_lens_tensor', + 'block_list', + 'block_mapping', + 'block_usage', + 'slot_mapping', + 'is_prompt', + 'block_indices', + 'block_offsets', + 'block_groups', ]) return attention_metadata @@ -1733,14 +1803,44 @@ def finish_measurements(self): from neural_compressor.torch.quantization import finalize_calibration finalize_calibration(self.model.model) - def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): - cfg = (batch_size, seq_len, is_prompt) + def _num_blocks(self, attn_metadata): + if attn_metadata.block_list is None: + return 0 + return attn_metadata.block_list.numel() + + def _phase(self, attn_metadata): + phase_type: PhaseType + is_prompt = attn_metadata.is_prompt + is_prefix_prefill = is_prompt and attn_metadata.block_list is not None + if is_prompt and is_prefix_prefill: + phase_type = PhaseType.PREFIX_PREFILL + elif is_prompt and not is_prefix_prefill: + phase_type = PhaseType.PREFILL + elif not is_prompt: + phase_type = PhaseType.DECODE + else: + raise ValueError("Unrecognized pass type, likely due to malformed " + "attention metadata") + return phase_type + + def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode): + is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching + cfg: Optional[tuple] = None + assert cfg is None, "Configs changed between 2D and 3D" + if is_prefix_caching: + phase = self._phase(attn_metadata) + num_blocks = self._num_blocks(attn_metadata) + cfg = (batch_size, seq_len, num_blocks, phase) + else: + phase = 'prompt' if attn_metadata.is_prompt else 'decode' + cfg = (batch_size, seq_len, phase) seen = cfg in self.seen_configs self.seen_configs.add(cfg) if not seen and not warmup_mode: - phase = 'prompt' if is_prompt else 'decode' - logger.warning("Configuration: (%s, %s, %s) was not warmed-up!", - phase, batch_size, seq_len) + 
logger.warning("Configuration: %s was not warmed-up!", + (phase.value, batch_size, seq_len, + num_blocks) if is_prefix_caching else + (phase, batch_size, seq_len)) def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], is_prompt: bool): @@ -1912,7 +2012,7 @@ def execute_model( batch_size = input_tokens.size(0) seq_len = self._seq_len(attn_metadata) use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) - self._check_config(batch_size, seq_len, is_prompt, warmup_mode) + self._check_config(batch_size, seq_len, attn_metadata, warmup_mode) lora_mask: torch.Tensor = None lora_logits_mask: torch.Tensor = None From 2d41a8f1ddeb9b9e961894cea8ebbf9a9809f345 Mon Sep 17 00:00:00 2001 From: Agata Dobrzyniewicz Date: Mon, 5 May 2025 16:34:26 +0300 Subject: [PATCH 2/2] Typo fix Signed-off-by: Agata Dobrzyniewicz --- vllm/attention/ops/hpu_paged_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index 9b2bcd5078ac..a97c36338d3c 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -84,4 +84,4 @@ def copy_blocks( ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dist) + cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
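
Usage note (not part of the patches above): the changes gate all new behaviour on cache_config.enable_prefix_caching and default VLLM_CONTIGUOUS_PA to "False", since contiguous PA does not support APC. Below is a minimal, hypothetical Python sketch of how the feature would be exercised through the public vLLM API once the patches are applied; the model name is a placeholder, and the explicit environment variable is shown only to make the constraint visible (the patched HPU model runner sets it itself).

import os

from vllm import LLM, SamplingParams

# The patched HPU model runner applies this default automatically when prefix
# caching is enabled; it is repeated here only to make the constraint explicit.
os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False")

# enable_prefix_caching=True turns on Automatic Prefix Caching. On HPU, a
# prompt whose prefix blocks are already cached now takes the prefix-prefill
# path added in hpu_attn.py (a block_list is passed to prompt_attention
# instead of None).
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
          enable_prefix_caching=True)

shared_prefix = "You are a helpful assistant. Answer concisely.\n" * 20
params = SamplingParams(temperature=0.0, max_tokens=32)

# The second request reuses the KV-cache blocks of the shared prefix, so only
# the uncached suffix tokens are recomputed during prefill.
print(llm.generate([shared_prefix + "Question A?"], params)[0].outputs[0].text)
print(llm.generate([shared_prefix + "Question B?"], params)[0].outputs[0].text)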