Merged

Commits (21)
30e46ce  Initial implementation (madamczyk-intel, Mar 20, 2025)
4b52067  Combine all prefills (madamczyk-intel, Mar 20, 2025)
4d1330a  Cleanup (madamczyk-intel, Mar 21, 2025)
0a8b692  Set requirements-hpu to dev/madamczyk/merged_prefill_stage0 (madamczyk-intel, Mar 21, 2025)
0b27f93  Cleanup (madamczyk-intel, Mar 24, 2025)
f774758  Simplify attn_bias (madamczyk-intel, Mar 24, 2025)
db9f3c5  Merge remote-tracking branch 'origin/habana_main' into dev/madamczyk/… (madamczyk-intel, Mar 24, 2025)
1e91232  Make ruff happy (madamczyk-intel, Mar 24, 2025)
c904973  Fix V1 (madamczyk-intel, Mar 24, 2025)
295018d  Remove attn_bias WA (madamczyk-intel, Mar 25, 2025)
a266501  Merge remote-tracking branch 'origin/habana_main' into dev/madamczyk/… (madamczyk-intel, Mar 25, 2025)
6d72802  Simplify prompt implementation selection (madamczyk-intel, Mar 25, 2025)
1b26e86  Make Ruff Happy Again (madamczyk-intel, Mar 25, 2025)
6bb93ea  Ruff: third time's a charm (madamczyk-intel, Mar 25, 2025)
d68c8c4  Merge remote-tracking branch 'origin/habana_main' into dev/madamczyk/… (madamczyk-intel, Mar 25, 2025)
2580e34  Add 'softmax_op' (madamczyk-intel, Mar 25, 2025)
09bbad1  Merge branch 'habana_main' into dev/madamczyk/merged_prefill_stage0 (michalkuligowski, Mar 26, 2025)
bcb90ea  Set hpu-extension to 708a89a329b721b36d44bd424da3472593a8d42c (madamczyk-intel, Mar 28, 2025)
0ec3cfd  Merge branch 'habana_main' into dev/madamczyk/merged_prefill_stage0 (madamczyk-intel, Mar 28, 2025)
a2bbd73  Merge branch 'habana_main' into dev/madamczyk/merged_prefill_stage0 (michalkuligowski, Mar 28, 2025)
3db74d9  Merge branch 'habana_main' into dev/madamczyk/merged_prefill_stage0 (madamczyk-intel, Mar 28, 2025)
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@21284c9
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@708a89a329b721b36d44bd424da3472593a8d42c
162 changes: 57 additions & 105 deletions vllm/attention/backends/hpu_attn.py
@@ -11,8 +11,7 @@
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.flags import enabled_flags
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
VLLMKVCache)
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
@@ -87,7 +86,6 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_mapping: Optional[torch.Tensor] = None
cross_block_groups: Optional[torch.Tensor] = None
cross_block_scales: Optional[torch.Tensor] = None
cross_block_usage: Optional[torch.Tensor] = None
cross_attn_bias: Optional[torch.Tensor] = None

@@ -134,9 +132,16 @@ def __init__(
self.block2batch_matmul = Matmul()
self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache()
HPUFusedSDPA = kernels.fsdpa()
self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
else ModuleFusedSDPA(HPUFusedSDPA)
self.fused_scaled_dot_product_attention = kernels.fsdpa()

self.prefill_impl = 'naive'
if "flex_attention" in enabled_flags():
self.prefill_impl = 'flex'
if "fsdpa" in enabled_flags():
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
self.prefill_impl = 'fsdpa'

self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.alibi_slopes = alibi_slopes
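The hunk above replaces the two independent booleans (`prefill_use_fusedsdpa`, `prefill_use_flex_attention`) with a single `prefill_impl` string, so exactly one prompt-attention implementation is selected and later passed as `impl=` to `ops.prompt_attention`. A minimal sketch of the selection order, with `enabled_flags()` taken from `vllm_hpu_extension.flags` as in the diff; the standalone helper below is only an illustration, not part of the PR:

```python
from vllm_hpu_extension.flags import enabled_flags


def select_prefill_impl(alibi_slopes=None) -> str:
    """Illustrative only: mirrors the __init__ logic above.

    'fsdpa' wins over 'flex', which wins over the 'naive' default.
    """
    impl = 'naive'
    if "flex_attention" in enabled_flags():
        impl = 'flex'
    if "fsdpa" in enabled_flags():
        # FusedSDPA prefill cannot be combined with ALiBi slopes.
        assert alibi_slopes is None, \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        impl = 'fsdpa'
    return impl
```

Encoding the choice as one string also lets the removed branching in `forward` (naive vs. fused vs. flex) collapse into a single `ops.prompt_attention` call further down.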
@@ -147,20 +152,6 @@ def __init__(
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags()
if self.prefill_use_fusedsdpa:
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
FusedSDPA)
except ImportError:
logger().warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")

self.prefill_use_flex_attention = "flex_attention" in enabled_flags()

suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
@@ -215,9 +206,11 @@ def forward(
value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
key_cache = None
value_cache = None
if attn_metadata.is_prompt and self.attn_type \
is not AttentionType.ENCODER_ONLY \
and attn_metadata.block_list is None:
is not AttentionType.ENCODER_ONLY \
and attn_metadata.block_list is None:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None and isinstance(kv_cache, tuple):
@@ -238,46 +231,24 @@ def forward(
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)

attn_bias = attn_metadata.attn_bias
if attn_bias is not None and self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
if attn_metadata is None or attn_metadata.block_list is None:
if (not self.prefill_use_fusedsdpa
and not self.prefill_use_flex_attention):
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads,
attn_bias.dtype, attn_bias.shape[-1])
attn_bias = attn_bias.tile(
(1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
else:
attn_bias = attn_metadata.attn_bias

if not self.prefill_use_flex_attention:
out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
fsdpa_op=self.fused_scaled_dot_product_attention
if self.prefill_use_fusedsdpa else None,
)
else:
out = ops.flex_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
scale=self.scale,
)

out = ops.prompt_attention(
impl=self.prefill_impl,
query=query.view(query_shape),
key=key.view(kv_shape),
value=value.view(kv_shape),
is_causal=True,
attn_bias=attn_bias,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
**self.common_attention_args())
else:
# TODO: enable FusedSDPA
out = HPUPagedAttention.forward_prefix(
@@ -288,12 +259,7 @@ def forward(
value_cache=value_cache,
block_list=attn_metadata.block_list,
attn_bias=attn_metadata.attn_bias,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
softmax_op=self.softmax,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
**self.common_attention_args())
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
@@ -304,18 +270,26 @@ def forward(
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
**self.common_attention_args())
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)

def common_attention_args(self):
fsdpa_op = self.fused_scaled_dot_product_attention.apply \
if self.fused_scaled_dot_product_attention is not None else None
return {
'scale': self.scale,
'matmul_qk_op': self.matmul_qk,
'matmul_av_op': self.matmul_av,
'batch2block_matmul_op': self.batch2block_matmul,
'block2batch_matmul_op': self.block2batch_matmul,
'fsdpa_op': fsdpa_op,
'keys_fetch_func': self.k_cache.fetch_from_cache,
'values_fetch_func': self.v_cache.fetch_from_cache,
'softmax_op': self.softmax,
}
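`common_attention_args()` bundles every operator handle shared by the prompt, prefix-caching and decode paths, so each call site only adds its path-specific tensors before splatting the dict; the underlying extension ops are assumed to tolerate any shared kwargs they do not need. A hedged sketch of the pattern, modeled on the `forward_prefix` call above (the wrapper name and the import path are assumptions, not part of the PR):

```python
from vllm.attention.ops.hpu_paged_attn import HPUPagedAttention


def prefix_prefill(query, key, value, key_cache, value_cache, attn_metadata,
                   common_args):
    """Illustrative call site: path-specific tensors are explicit, everything
    operator-related (scale, matmul/softmax ops, fsdpa_op, cache fetchers)
    comes from common_args, i.e. the dict built by common_attention_args().
    """
    return HPUPagedAttention.forward_prefix(
        query=query,
        key=key,
        value=value,
        key_cache=key_cache,
        value_cache=value_cache,
        block_list=attn_metadata.block_list,
        attn_bias=attn_metadata.attn_bias,
        **common_args)
```

Note also that `fsdpa_op` resolves to `kernels.fsdpa().apply` only when the fused kernel is available; otherwise it is passed as `None` (the 'fsdpa' flag is presumably only enabled when the kernel is present).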

def forward_encoder_decoder(
self,
query: torch.Tensor,
@@ -377,34 +351,19 @@ def forward_encoder_decoder(

query_shape = (batch_size, -1, self.num_heads, self.head_size)
kv_shape = (batch_size, -1, self.num_kv_heads, self.head_size)
# Just a workaround, to make ops.prompt_attention go into the
# torch ops assembly path.
# TODO: add new prompt_attention op in vllm_hpu_extension
# which calls FusedSDPA with causal = False.
attn_bias = torch.zeros((batch_size, 1, 1, 1),
device=query.device,
dtype=torch.bool)

out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
fsdpa_op=self.fused_scaled_dot_product_attention
if self.prefill_use_fusedsdpa else None,
)
out = ops.prompt_attention(impl=self.prefill_impl,
query=query.view(query_shape),
key=key.view(kv_shape),
value=value.view(kv_shape),
attn_bias=None,
is_causal=False,
**self.common_attention_args())
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
block_list = attn_metadata.cross_block_list
block_mapping = attn_metadata.cross_block_mapping
block_scales = attn_metadata.cross_block_scales
block_groups = attn_metadata.cross_block_groups
attn_bias = attn_metadata.cross_attn_bias
# Decoding run.
@@ -415,15 +374,8 @@
block_list=block_list,
block_mapping=block_mapping,
block_bias=attn_bias,
block_scales=block_scales,
block_groups=block_groups,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
**self.common_attention_args())
# Reshape the output tensor.
return output.view(batch_size, -1, hidden_size)

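A second effect of the refactor, visible in `forward_encoder_decoder` above: the old trick of feeding an all-zeros boolean `attn_bias` just to push `ops.prompt_attention` onto the fused path is gone, and causality is now stated explicitly via `is_causal`. The sketch below contrasts the two prefill call shapes; the wrapper and its parameters are hypothetical, and the exact `ops.prompt_attention` signature is assumed from the diff rather than verified against the extension:

```python
import vllm_hpu_extension.ops as ops


def prefill_attention(impl, query, key, value, attn_bias, seq_lens_tensor,
                      common_args, cross_attention=False):
    """Illustrative wrapper: same op, two call shapes."""
    if cross_attention:
        # Encoder/decoder cross-attention: not causal, no bias tensor.
        # This replaces the previous dummy torch.zeros(..., dtype=torch.bool)
        # attn_bias workaround.
        return ops.prompt_attention(impl=impl, query=query, key=key,
                                    value=value, is_causal=False,
                                    attn_bias=None, **common_args)
    # Decoder self-attention: causal, with an optional explicit bias and
    # per-sequence valid lengths for the fused kernel.
    return ops.prompt_attention(impl=impl, query=query, key=key, value=value,
                                is_causal=True, attn_bias=attn_bias,
                                valid_seq_lengths=seq_lens_tensor,
                                **common_args)
```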
1 change: 0 additions & 1 deletion vllm/attention/ops/hpu_paged_attn.py
@@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata:
block_usage: Optional[torch.Tensor]
block_indices: Optional[torch.Tensor]
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]


3 changes: 0 additions & 3 deletions vllm/v1/attention/backends/hpu_attn.py
@@ -46,7 +46,6 @@ def make_prefill_metadata(cls, seq_lens_tensor, num_prefills,
block_usage=None,
block_indices=None,
block_offsets=None,
block_scales=None,
block_groups=None,
attn_bias=None,
num_decode_tokens=0,
@@ -68,7 +67,6 @@ def make_cached_prefill_metadata(cls, seq_lens_tensor, context_lens_tensor,
block_usage=None,
block_indices=None,
block_offsets=None,
block_scales=None,
block_groups=None,
attn_bias=None,
num_decode_tokens=0,
@@ -87,7 +85,6 @@ def make_decode_metadata(cls, block_list, block_usage, block_groups,
block_mapping=None,
block_indices=None,
block_offsets=None,
block_scales=None,
attn_bias=None,
seq_lens_tensor=None,
context_lens_tensor=None,
14 changes: 1 addition & 13 deletions vllm/v1/worker/hpu_model_runner.py
@@ -19,7 +19,6 @@
import vllm_hpu_extension.environment as environment
from vllm_hpu_extension.bucketing import HPUBucketingContext
from vllm_hpu_extension.flags import enabled_flags
from vllm_hpu_extension.ops import batch2block, block2batch
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes

from vllm.attention.backends.abstract import AttentionType
@@ -453,16 +452,6 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
attn_bias=attn_bias)
return metadata

def _set_block_scales(self, metadata, device):
block_mapping = metadata.block_mapping
ones = torch.ones((block_mapping.size(0), ),
device=device,
dtype=block_mapping.dtype)
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
metadata = metadata._replace(block_scales=block_scales)
return metadata

def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
slot_mapping = metadata.slot_mapping.flatten()
indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
@@ -483,7 +472,6 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
else:
attn_metadata = self._set_block_mapping(attn_metadata, batch_size,
device, dtype)
attn_metadata = self._set_block_scales(attn_metadata, device)
attn_metadata = self._set_indices_and_offsets(attn_metadata,
self.block_size,
attn_metadata.is_prompt)
@@ -599,7 +587,7 @@ def trim_attn_metadata(metadata: HPUAttentionMetadataV1) -> object:
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
'attn_bias', 'seq_lens_tensor', 'context_lens_tensor', 'block_list',
'block_mapping', 'block_usage', 'slot_mapping', 'is_prompt',
'block_indices', 'block_offsets', 'block_scales', 'block_groups'
'block_indices', 'block_offsets', 'block_groups'
])
return attention_metadata

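The `_set_block_scales` helper removed above (and its `_set_cross_block_scales` twin in the encoder/decoder runner below) precomputed a per-block normalization factor on the host and shipped it to the kernels as `block_scales`; since the metadata field is gone everywhere, that normalization is presumably handled inside the pinned vllm-hpu-extension revision now. For reference, a sketch of the dropped computation, assuming `batch2block`/`block2batch` are still exported by the extension with their usual scatter/gather semantics:

```python
import torch
from vllm_hpu_extension.ops import batch2block, block2batch


def reference_block_scales(block_mapping: torch.Tensor) -> torch.Tensor:
    """Reference for the removed logic: one scale per block, equal to the
    reciprocal of how many blocks map to the same batch entry, clamped so
    padding entries end up with a scale of 1.
    """
    ones = torch.ones((block_mapping.size(0), ),
                      device=block_mapping.device,
                      dtype=block_mapping.dtype)
    sums = batch2block(block2batch(ones, block_mapping), block_mapping)
    return torch.reciprocal(torch.maximum(ones, sums))
```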
15 changes: 0 additions & 15 deletions vllm/worker/hpu_enc_dec_model_runner.py
@@ -9,7 +9,6 @@

import habana_frameworks.torch as htorch
import torch
from vllm_hpu_extension.ops import batch2block, block2batch

from vllm.attention import AttentionMetadata
from vllm.distributed import broadcast_tensor_dict
@@ -84,17 +83,6 @@ def _set_cross_block_mapping(self, metadata, batch_size, device, dtype):
cross_attn_bias=cross_attn_bias)
return metadata

def _set_cross_block_scales(self, metadata, device):
cross_block_mapping = metadata.cross_block_mapping
ones = torch.ones((cross_block_mapping.size(0), ),
device=device,
dtype=cross_block_mapping.dtype)
sums = batch2block(block2batch(ones, cross_block_mapping),
cross_block_mapping)
cross_block_scales = torch.reciprocal(torch.maximum(ones, sums))
metadata = metadata._replace(cross_block_scales=cross_block_scales)
return metadata

def _set_cross_indices_and_offsets(self, metadata, block_size):
cross_slot_mapping = metadata.cross_slot_mapping.flatten()
indices = torch.div(cross_slot_mapping,
@@ -128,7 +116,6 @@ def _update_cross_metadata(self, attn_metadata, batch_size, seq_len,
else:
attn_metadata = self._set_cross_block_mapping(
attn_metadata, batch_size, device, dtype)
attn_metadata = self._set_cross_block_scales(attn_metadata, device)

return attn_metadata

@@ -526,7 +513,6 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
'is_prompt',
'block_indices',
'block_offsets',
'block_scales',
'block_groups',
'num_prefill_tokens',
'num_decode_tokens',
@@ -540,7 +526,6 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
'cross_slot_mapping',
'cross_block_mapping',
'cross_block_groups',
'cross_block_scales',
'cross_block_usage',
'cross_attn_bias',
])