32 changes: 20 additions & 12 deletions vllm/attention/backends/hpu_attn.py
@@ -57,16 +57,16 @@ def get_kv_cache_shape(
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dsts: torch.Tensor,
     ) -> None:
-        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)

     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dsts: torch.Tensor,
     ) -> None:
-        HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
+        HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)


 @dataclass
@@ -77,6 +77,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     is_prompt: bool
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
+    context_lens_tensor: Optional[torch.Tensor]


 class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
@@ -198,8 +199,7 @@ def forward(
         key_cache = None
         value_cache = None
         if attn_metadata.is_prompt and self.attn_type \
-            is not AttentionType.ENCODER_ONLY \
-                and attn_metadata.block_list is None:
+            is not AttentionType.ENCODER_ONLY:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
         if kv_cache is not None and isinstance(kv_cache, tuple):
@@ -229,6 +229,9 @@ def forward(
                 attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
                 attn_bias.add_(position_bias)

+            block_list = attn_metadata.block_list if attn_metadata \
+                and attn_metadata.block_list is not None else None
+
             out = ops.prompt_attention(
                 impl=self.prefill_impl,
                 query=query.view(query_shape),
@@ -237,23 +240,25 @@ def forward(
                 is_causal=True,
                 attn_bias=attn_bias,
                 valid_seq_lengths=attn_metadata.seq_lens_tensor,
-                **self.common_attention_args())
+                **self.common_attention_args(block_list, key_cache,
+                                             value_cache))
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
             output = HPUPagedAttention.forward_decode(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_list=attn_metadata.block_list,
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
                 block_groups=attn_metadata.block_groups,
-                **self.common_attention_args())
+                **self.common_attention_args(attn_metadata.block_list,
+                                             key_cache, value_cache))
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)

-    def common_attention_args(self):
+    def common_attention_args(self,
+                              block_list=None,
+                              key_cache=None,
+                              value_cache=None):
         fsdpa_op = self.fused_scaled_dot_product_attention.apply \
             if self.fused_scaled_dot_product_attention is not None else None
         return {
@@ -266,6 +271,9 @@ def common_attention_args(self):
             'keys_fetch_func': self.k_cache.fetch_from_cache,
             'values_fetch_func': self.v_cache.fetch_from_cache,
             'softmax_op': self.softmax,
+            'block_list': block_list,
+            'key_cache': key_cache,
+            'value_cache': value_cache,
         }
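
Note: the refactor above routes `block_list`, `key_cache`, and `value_cache` through the shared kwargs returned by `common_attention_args`, so the prefill and decode paths unpack the same dictionary. A minimal sketch of that pattern with stand-in functions rather than the real `ops.prompt_attention`/`ops.flat_pa` kernels (names are illustrative only):

```python
# Illustration only: stand-ins for the HPU attention ops, showing how one
# kwargs builder can feed two different call sites.
def common_attention_args(block_list=None, key_cache=None, value_cache=None):
    # Arguments shared by the prompt and decode paths live in one place.
    return {
        'block_list': block_list,
        'key_cache': key_cache,
        'value_cache': value_cache,
    }


def prompt_attention(query: str, is_causal: bool, **kwargs) -> str:
    return f"prompt({query}, causal={is_causal}, blocks={kwargs['block_list']})"


def forward_decode(query: str, block_mapping, **kwargs) -> str:
    return f"decode({query}, mapping={block_mapping}, blocks={kwargs['block_list']})"


print(prompt_attention('q0', True, **common_attention_args([3, 7], 'K', 'V')))
print(forward_decode('q1', [0, 0, 1], **common_attention_args([3, 7, 9], 'K', 'V')))
```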


36 changes: 9 additions & 27 deletions vllm/attention/ops/hpu_paged_attn.py
@@ -5,7 +5,7 @@
 ###############################################################################

 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from vllm_hpu_extension import cache_ops, ops
@@ -63,43 +63,25 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
     def forward_decode(**kwargs) -> torch.Tensor:
         return ops.flat_pa(**kwargs)

-    @staticmethod
-    def forward_prefix(
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        subquery_start_loc: torch.Tensor,
-        seq_lens_tensor: torch.Tensor,
-        context_lens: torch.Tensor,
-        max_query_len: int,
-        alibi_slopes: Optional[torch.Tensor],
-        sliding_window: Optional[int],
-    ) -> torch.Tensor:
-        raise NotImplementedError(
-            "forward_prefix is not implemented for HPUPagedAttention")
-
     @staticmethod
     def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        src_to_dsts: torch.Tensor,
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
-        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts)

         src_value_cache = src_kv_cache[1]
         dst_value_cache = dst_kv_cache[1]
-        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts)

     @staticmethod
     def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        src_to_dsts: torch.Tensor,
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
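
For reference, `src_to_dsts` is now a tensor of `[src, dst]` block-index pairs instead of a `Dict`, and each cache is a `(key, value)` tuple. A hedged pure-PyTorch sketch of the copy semantics assumed here (the real work is done by `vllm_hpu_extension.cache_ops`, whose kernels are not reproduced):

```python
# Reference emulation of copy_blocks over (key, value) cache tuples, assuming
# src_to_dsts is a [num_pairs, 2] tensor of [source_block, destination_block].
from typing import List, Tuple

import torch


def copy_blocks_reference(kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
                          src_to_dsts: torch.Tensor) -> None:
    for key_cache, value_cache in kv_caches:
        for src, dst in src_to_dsts.tolist():
            key_cache[dst].copy_(key_cache[src])
            value_cache[dst].copy_(value_cache[src])


# Toy cache: 4 blocks of shape [block_size=2, head_dim=3] per layer.
key = torch.arange(4 * 2 * 3, dtype=torch.float32).reshape(4, 2, 3)
value = key * 10
copy_blocks_reference([(key, value)], torch.tensor([[0, 3], [1, 2]]))
assert torch.equal(key[3], key[0]) and torch.equal(value[2], value[1])
```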
134 changes: 117 additions & 17 deletions vllm/worker/hpu_model_runner.py
@@ -14,7 +14,7 @@
 import os
 import time
 from array import array
-from enum import IntEnum
+from enum import Enum, IntEnum
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
                     Optional, Set, Tuple, Type, TypeVar, Union)

@@ -75,6 +75,12 @@
 DUMMY_TOKEN_ID = -1


+class PhaseType(Enum):
+    PREFILL = 'prefill'
+    PREFIX_PREFILL = 'prefix_prefill'
+    DECODE = 'decode'
+
+
 def subtuple(obj: object,
              typename: str,
              to_copy: List[str],
@@ -213,20 +219,40 @@ def _compile_region(self, model, name, module):

     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                        dtype):
-        prefill_metadata = attn_metadata
-        if prefill_metadata is None or self.prefill_use_fusedsdpa:
+        if (attn_metadata is None
+                or (self.prefill_use_fusedsdpa \
+                    and attn_metadata.block_list is None)
+                or not attn_metadata.is_prompt):
             return attn_metadata

+        prefill_metadata = attn_metadata
+
         seq_lens_t = prefill_metadata.seq_lens_tensor
+        context_lens_t = prefill_metadata.context_lens_tensor
+        query_lens_t = seq_lens_t - context_lens_t
+
+        block_list = attn_metadata.block_list
+        max_context_len = (block_list.size(-1) //
+                           batch_size if block_list is not None else 0)
+        max_context_len = max_context_len * self.block_size
+        past_mask = torch.arange(0,
+                                 max_context_len,
+                                 dtype=torch.int32,
+                                 device=device)
+        past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(
+            context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand(
+                batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1))
+
         len_mask = (torch.arange(0, seq_len, device=device,
                                  dtype=torch.int32).view(1, seq_len).ge(
-                                     seq_lens_t.unsqueeze(-1)).view(
+                                     query_lens_t.unsqueeze(-1)).view(
                                          batch_size, 1, 1, seq_len))
         causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
                                             device=device,
                                             dtype=torch.bool),
                                  diagonal=1)
         mask = causal_mask.logical_or(len_mask)
+        mask = torch.concat((past_mask, mask), dim=-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
         attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
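
To make the new bias construction concrete, here is a standalone toy example in the same spirit as the code above: a `past_mask` over the cached context is concatenated in front of the causal/length mask over the current query tokens, giving a bias of shape `[batch, 1, seq_len, max_context_len + seq_len]`. All sizes below are made up.

```python
# Toy-sized sketch of the prefix-prefill attention bias: 0 where attention is
# allowed, -inf where masked. Mirrors the structure of _set_attn_bias above.
import math

import torch

batch_size, seq_len, block_size = 2, 4, 2
context_lens_t = torch.tensor([2, 0])      # cached tokens per sequence
query_lens_t = torch.tensor([2, 3])        # new (uncached) tokens per sequence
max_context_len = 2 * block_size           # blocks per sequence * block size

# Past positions beyond each sequence's cached context are masked.
past_mask = (torch.arange(max_context_len).view(1, -1).expand(batch_size, -1)
             .ge(context_lens_t.view(-1, 1))
             .view(batch_size, 1, 1, -1).expand(batch_size, 1, seq_len, -1))

# Padded query positions and future positions are masked.
len_mask = (torch.arange(seq_len).view(1, seq_len)
            .ge(query_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len))
causal_mask = torch.triu(torch.ones(batch_size, 1, seq_len, seq_len,
                                    dtype=torch.bool), diagonal=1)

mask = torch.concat((past_mask, causal_mask.logical_or(len_mask)), dim=-1)
attn_bias = torch.zeros_like(mask, dtype=torch.float32).masked_fill_(
    mask, -math.inf)
print(attn_bias.shape)  # torch.Size([2, 1, 4, 8])
```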
@@ -517,6 +543,11 @@ def __init__(
                                                    False, self.max_model_len)
         self.graphed_buckets: Set[Any] = set()
         self._set_gc_threshold()
+        if self.vllm_config.cache_config.enable_prefix_caching:
+            os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False")
+            assert os.environ.get(
+                "VLLM_CONTIGUOUS_PA",
+                "").lower() != "true", "Contiguous PA doesn't support APC"
         self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH

         # For multi-step scheduling
@@ -702,6 +733,10 @@ def _prepare_prompt(
                     computed_block_nums) > 0 and self.sliding_window is None:
                 # Prefix is not supported with sliding_window
                 context_len = len(computed_block_nums) * self.block_size
+                if context_len == seq_len \
+                        and self.vllm_config.cache_config.enable_prefix_caching:
+                    # Fully cached prompt - compute only last token
+                    context_len = context_len - 1
                 prompt_tokens = prompt_tokens[context_len:]
                 prefix_block_tables.append(computed_block_nums)
             elif self.scheduler_config.chunked_prefill_enabled:
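
The fully-cached special case added above is plain arithmetic: if every prompt token is already cached, back off by one so prefill still computes the final position. A tiny illustration (no vLLM APIs; the helper name is hypothetical):

```python
# If context_len == seq_len the whole prompt is cached and prefill would have
# nothing to compute, so the last token is treated as uncached instead.
def effective_context_len(context_len: int, seq_len: int,
                          prefix_caching: bool) -> int:
    if prefix_caching and context_len == seq_len:
        return context_len - 1
    return context_len


assert effective_context_len(8, 8, True) == 7   # fully cached prompt
assert effective_context_len(4, 8, True) == 4   # partially cached prompt
assert effective_context_len(8, 8, False) == 8  # prefix caching disabled
```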
@@ -779,12 +814,33 @@ def _prepare_prompt(
             if lora_id > 0:
                 lora_requests.add(seq_group_metadata.lora_request)

-            lora_index_mapping += [lora_id] * (max_prompt_len - context_len)
+            lora_index_mapping += [lora_id] * max_prompt_len
             lora_prompt_mapping.extend(
                 [lora_id] *
-                (max_prompt_len - context_len
+                (max_prompt_len
                  if seq_group_metadata.sampling_params.prompt_logprobs else 1))

+        if any(context_lens):
+            assert not self.scheduler_config.chunked_prefill_enabled
+            # prefix caching
+
+            max_num_block = max(len(bt) for bt in prefix_block_tables)
+            prefix_block_list = list(
+                itertools.chain.from_iterable(
+                    bt if len(bt) == max_num_block else bt +
+                    ([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
+                    for bt in prefix_block_tables))
+
+            pad_len = len(prefix_block_list)
+            prefix_block_list = pad_list(prefix_block_list, pad_len,
+                                         _PAD_BLOCK_ID)
+
+            prefix_block_list_tensor = torch.tensor(prefix_block_list,
+                                                    dtype=torch.long,
+                                                    device=self.device)
+        else:
+            prefix_block_list_tensor = None
+
         input_tokens = make_tensor_with_pad(input_tokens,
                                             max_len=max_prompt_len,
                                             pad=0,
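
The flattening and padding of the ragged prefix block tables can be checked in isolation. In the sketch below `_PAD_BLOCK_ID` is assumed to be a reserved pad block index (0 is used purely for illustration); the subsequent `pad_list` call above pads to `pad_len = len(prefix_block_list)`, which, assuming `pad_list` pads up to the given length, leaves the list unchanged here.

```python
# Toy reproduction of the prefix block-table flattening: each per-sequence
# block table is right-padded to the longest one, then flattened.
import itertools

_PAD_BLOCK_ID = 0  # assumed reserved pad block id, for illustration only
prefix_block_tables = [[5, 6, 7], [9], []]

max_num_block = max(len(bt) for bt in prefix_block_tables)
prefix_block_list = list(
    itertools.chain.from_iterable(
        bt if len(bt) == max_num_block else bt +
        ([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
        for bt in prefix_block_tables))

print(prefix_block_list)  # [5, 6, 7, 9, 0, 0, 0, 0, 0]
```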
@@ -807,18 +863,23 @@ def _prepare_prompt(
                                        dtype=torch.long,
                                        device=self.device)

+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.long,
+                                           device=self.device)
+
         block_indices, block_offsets = precompute_indices_and_offsets(
             self.block_size, slot_mapping, True)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
-            block_list=None,
+            block_list=prefix_block_list_tensor,
             block_mapping=None,
             block_usage=None,
             block_indices=block_indices,
             block_offsets=block_offsets,
             block_groups=None,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
+            context_lens_tensor=context_lens_tensor,
             num_prefills=real_num_seqs,
             num_prefill_tokens=sum_query_len,
             num_decode_tokens=0,
@@ -987,6 +1048,7 @@ def _prepare_decode(
             block_groups=block_groups,
             attn_bias=None,
             seq_lens_tensor=None,
+            context_lens_tensor=None,
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=num_decode_tokens,
@@ -1091,7 +1153,7 @@ def prepare_input_tensors(
             # FIXME: We need to adjust selected_token_indices to accommodate
             # for padding
             max_len = input_tokens.size(1)
-            paddings = [max_len - s for s in seq_lens]
+            paddings = [max_len - q for q in query_lens]
             paddings = [0] + paddings[:-1]
             paddings = list(itertools.accumulate(paddings))
             paddings_prompt_logprobs = []
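
The switch from `seq_lens` to `query_lens` matters because, with prefix caching, the padded input rows only hold the uncached query tokens, presumably making query length the right base for the offset correction. A quick worked example of the resulting offsets (pure arithmetic):

```python
# With query_lens [3, 1, 2] padded to max_len 4, row i starts 4*i tokens into
# the flattened batch, so its selected indices must be shifted by the total
# padding of all earlier rows: [0, 1, 4].
import itertools

query_lens = [3, 1, 2]
max_len = 4

paddings = [max_len - q for q in query_lens]      # [1, 3, 2]
paddings = [0] + paddings[:-1]                    # [0, 1, 3]
paddings = list(itertools.accumulate(paddings))   # [0, 1, 4]
print(paddings)
```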
@@ -1187,9 +1249,17 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         # input_hash(123) != input_hash(321)
         # input_hash("abc") != input_hash("cba")
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
-            'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
-            'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
-            'block_offsets', 'block_groups'
+            'attn_bias',
+            'seq_lens_tensor',
+            'context_lens_tensor',
+            'block_list',
+            'block_mapping',
+            'block_usage',
+            'slot_mapping',
+            'is_prompt',
+            'block_indices',
+            'block_offsets',
+            'block_groups',
         ])
         return attention_metadata

@@ -1733,14 +1803,44 @@ def finish_measurements(self):
         from neural_compressor.torch.quantization import finalize_calibration
         finalize_calibration(self.model.model)

-    def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
-        cfg = (batch_size, seq_len, is_prompt)
+    def _num_blocks(self, attn_metadata):
+        if attn_metadata.block_list is None:
+            return 0
+        return attn_metadata.block_list.numel()
+
+    def _phase(self, attn_metadata):
+        phase_type: PhaseType
+        is_prompt = attn_metadata.is_prompt
+        is_prefix_prefill = is_prompt and attn_metadata.block_list is not None
+        if is_prompt and is_prefix_prefill:
+            phase_type = PhaseType.PREFIX_PREFILL
+        elif is_prompt and not is_prefix_prefill:
+            phase_type = PhaseType.PREFILL
+        elif not is_prompt:
+            phase_type = PhaseType.DECODE
+        else:
+            raise ValueError("Unrecognized pass type, likely due to malformed "
+                             "attention metadata")
+        return phase_type
+
+    def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode):
+        is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
+        cfg: Optional[tuple] = None
+        assert cfg is None, "Configs changed between 2D and 3D"
+        if is_prefix_caching:
+            phase = self._phase(attn_metadata)
+            num_blocks = self._num_blocks(attn_metadata)
+            cfg = (batch_size, seq_len, num_blocks, phase)
+        else:
+            phase = 'prompt' if attn_metadata.is_prompt else 'decode'
+            cfg = (batch_size, seq_len, phase)
         seen = cfg in self.seen_configs
         self.seen_configs.add(cfg)
         if not seen and not warmup_mode:
-            phase = 'prompt' if is_prompt else 'decode'
-            logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
-                           phase, batch_size, seq_len)
+            logger.warning("Configuration: %s was not warmed-up!",
+                           (phase.value, batch_size, seq_len,
+                            num_blocks) if is_prefix_caching else
+                           (phase, batch_size, seq_len))

     def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
                          is_prompt: bool):
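
The warm-up bookkeeping above keys each configuration by phase and, when prefix caching is enabled, also by the number of prefix blocks. A compact standalone sketch of that bucketing (the `PhaseType` enum mirrors the one added earlier; `ConfigTracker` is a made-up wrapper and the warning is just printed here):

```python
# Minimal sketch of the seen-configuration tracking: configs observed during
# warm-up are recorded, and an unseen config outside warm-up triggers a warning.
from enum import Enum


class PhaseType(Enum):
    PREFILL = 'prefill'
    PREFIX_PREFILL = 'prefix_prefill'
    DECODE = 'decode'


class ConfigTracker:

    def __init__(self, prefix_caching: bool):
        self.prefix_caching = prefix_caching
        self.seen_configs: set = set()

    def check(self, batch_size, seq_len, phase, num_blocks, warmup_mode):
        cfg = ((batch_size, seq_len, num_blocks, phase)
               if self.prefix_caching else (batch_size, seq_len, phase))
        seen = cfg in self.seen_configs
        self.seen_configs.add(cfg)
        if not seen and not warmup_mode:
            print(f"Configuration {cfg} was not warmed-up!")


tracker = ConfigTracker(prefix_caching=True)
tracker.check(2, 128, PhaseType.PREFIX_PREFILL, 4, warmup_mode=True)   # records
tracker.check(2, 128, PhaseType.PREFIX_PREFILL, 4, warmup_mode=False)  # silent
tracker.check(4, 128, PhaseType.DECODE, 8, warmup_mode=False)          # warns
```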
@@ -1912,7 +2012,7 @@ def execute_model(
         batch_size = input_tokens.size(0)
         seq_len = self._seq_len(attn_metadata)
         use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
-        self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
+        self._check_config(batch_size, seq_len, attn_metadata, warmup_mode)

         lora_mask: torch.Tensor = None
         lora_logits_mask: torch.Tensor = None