bs/seq bucketing for prompt and decode (#33)
* Bucketing/Warmup WIP

* Cleanup

* Revert "Fix model_output_idx on HPU (#27)"

This reverts commit 90dfa92.

* Rework selected_token_indices fix to also work with block_size padding

* Simple prompt attention POC

* Remove cumsum

* MQA/GQA support for simple prompt_attention

* Cleanup

* Fix typo

* Restore profiling runs
madamczykhabana authored May 21, 2024
1 parent 2664659 commit ce1670b
Showing 5 changed files with 225 additions and 763 deletions.
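Note: the bucketing that gives this commit its title is not visible in the hunks rendered below; it lives in the changed files whose diffs did not load on this page. As a rough, hypothetical sketch of the idea -- pad batch size and sequence length up to a small set of precompiled bucket shapes so the HPU graphs created during warmup can be reused -- it might look like the following. The function name and bucket values are illustrative assumptions, not code from this commit.

# Hypothetical illustration of bs/seq bucketing -- names and values are
# assumptions, not taken from this commit.
def find_bucket(value: int, buckets: list[int]) -> int:
    """Round value up to the nearest bucket; fall back to the largest one."""
    for bucket in sorted(buckets):
        if value <= bucket:
            return bucket
    return max(buckets)

batch_size_buckets = [1, 2, 4, 8, 16, 32]
seq_len_buckets = [128, 256, 512, 1024]

padded_bs = find_bucket(3, batch_size_buckets)       # -> 4
padded_seq_len = find_bucket(200, seq_len_buckets)   # -> 256
# Inputs are then padded to (padded_bs, padded_seq_len) so prompt and decode
# steps hit shapes that were already warmed up, avoiding recompilation.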
57 changes: 22 additions & 35 deletions vllm/attention/backends/habana_attn.py
@@ -2,14 +2,13 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

import importlib
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch
import math
import vllm.hpu.xops as xops
from vllm.hpu.attn_bias import (AttentionBias,
                                BlockDiagonalCausalMask,
                                LowerTriangularMaskWithTensorBias)

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -18,7 +17,6 @@
from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention,
                                                   HabanaPagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)

@@ -119,11 +117,11 @@ def __post_init__(self):
class HabanaAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
@@ -196,48 +194,37 @@ def forward(
            HabanaPagedAttention.write_to_paged_cache(key, value, key_cache,
                                                      value_cache,
                                                      attn_metadata.slot_mapping,
                                                      attn_metadata.kv_cache_dtype,
                                                      attn_metadata.prefill_metadata is not None)

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
                if self.num_kv_heads != self.num_heads:
                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                    # project the key and value tensors to the desired number of
                    # heads.
                    # TODO(woosuk): Use MQA/GQA kernels for higher performance.
                    query = query.view(query.shape[0], self.num_kv_heads,
                                       self.num_queries_per_kv,
                                       query.shape[-1])
                    key = key[:, :,
                              None, :].expand(key.shape[0], self.num_kv_heads,
                                              self.num_queries_per_kv,
                                              key.shape[-1])
                    value = value[:, :,
                                  None, :].expand(value.shape[0],
                                                  self.num_kv_heads,
                                                  self.num_queries_per_kv,
                                                  value.shape[-1])

                # TODO: move this outside of model
                if prefill_meta.attn_bias is None:
                    if self.alibi_slopes is None:
                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
                            [seq_len] * batch_size)
                        lens = torch.tensor(attn_metadata.prefill_metadata.seq_lens, device=query.device, dtype=torch.int32)
                        len_mask = (torch.arange(0, seq_len, device=query.device, dtype=torch.int32)
                                    .view(1, seq_len)
                                    .ge(lens.unsqueeze(-1))
                                    .view(batch_size, 1, 1, seq_len))
                        causal_mask = torch.triu(
                            torch.ones((batch_size, 1, seq_len, seq_len), device=query.device, dtype=torch.bool),
                            diagonal=1
                        )
                        mask = causal_mask.logical_or(len_mask)
                        attn_bias = (torch.zeros_like(mask, dtype=query.dtype)
                                     .masked_fill_(mask, -math.inf))
                        if self.sliding_window is not None:
                            attn_bias = attn_bias.make_local_attention(
                                self.sliding_window)
                            raise NotImplementedError("Sliding window is not supported on HPU")
                        prefill_meta.attn_bias = attn_bias
                    else:
                        prefill_meta.attn_bias = _make_alibi_bias(
                            self.alibi_slopes, self.num_kv_heads, batch_size,
                            seq_len, query.dtype)
                query_shape = (batch_size, seq_len, self.num_kv_heads, self.num_queries_per_kv, self.head_size) if self.num_kv_heads != self.num_heads else (batch_size, seq_len, self.num_heads, self.head_size)
                kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.num_queries_per_kv, self.head_size) if self.num_kv_heads != self.num_heads else (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)
                out = xops.memory_efficient_attention_forward(
                query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
                kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)
                out = xops.prompt_attention(
                    query.view(query_shape),
                    key.view(kv_shape),
                    value.view(kv_shape),
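For readers unfamiliar with the new mask construction above: the prompt path now builds a single additive bias that combines a causal mask with a right-padding mask derived from the real sequence lengths. A minimal, self-contained sketch of that construction (CPU tensors and made-up toy sizes; not the committed code) is:

import math

import torch

# Toy sizes for illustration only.
batch_size, seq_len = 2, 6
seq_lens = [6, 4]  # real prompt lengths; the second prompt is right-padded

lens = torch.tensor(seq_lens, dtype=torch.int32)
# True where a key position lies beyond the real length of its sequence.
len_mask = (torch.arange(0, seq_len, dtype=torch.int32)
            .view(1, seq_len)
            .ge(lens.unsqueeze(-1))
            .view(batch_size, 1, 1, seq_len))
# True strictly above the diagonal, i.e. for future positions.
causal_mask = torch.triu(
    torch.ones((batch_size, 1, seq_len, seq_len), dtype=torch.bool),
    diagonal=1)
mask = causal_mask.logical_or(len_mask)
# Additive bias: 0 where attention is allowed, -inf where it is masked.
attn_bias = (torch.zeros_like(mask, dtype=torch.float32)
             .masked_fill_(mask, -math.inf))
print(attn_bias[1, 0])  # key positions 4 and 5 are masked for the short prompt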
85 changes: 30 additions & 55 deletions vllm/hpu/xops.py
@@ -5,62 +5,37 @@
# LICENSE file in the root directory of this source tree.
###############################################################################

import habana_frameworks.torch as htorch
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from .attn_bias import AttentionBias, BlockDiagonalCausalMask
from typing import Optional

try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    print("Not using HPU fused scaled dot-product attention kernel.")
    FusedSDPA = None
import vllm.hpu.utils

def memory_efficient_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    assert attn_bias is not None, "Attention mask is required for prompt processing"
    dim = query.dim()
    is_causal = isinstance(attn_bias, BlockDiagonalCausalMask)
    if FusedSDPA and (is_causal or attn_bias is None):
        bs = query.shape[0]
        seq_len_q = query.shape[1]
        seq_len_kv = key.shape[1]
        heads = query.shape[-2] if dim != 5 else query.shape[-3]
        attn_groups = 1 if dim != 5 else query.shape[-2]
        head_dim = query.shape[-1]
        if dim == 4:
            # [bs, seq_len, 1, heads, head_dim] -> [bs, heads, seq_len, head_dim]
            query = query.reshape(bs, seq_len_q, heads, head_dim).permute(0, 2, 1, 3)
            key = key.reshape(bs, seq_len_kv, heads, head_dim).permute(0, 2, 1, 3)
            value = value.reshape(bs, seq_len_kv, heads, head_dim).permute(0, 2, 1, 3)
        elif dim == 5:
            # [bs, seq_len, heads, attn_groups, head_dim] -> [bs, heads, attn_groups, seq_len, head_dim]
            query = query.reshape(bs, seq_len_q, heads, attn_groups, head_dim).permute(0, 2, 3, 1, 4)
            key = key.reshape(bs, seq_len_kv, heads, attn_groups, head_dim).permute(0, 2, 3, 1, 4)
            value = value.reshape(bs, seq_len_kv, heads, attn_groups, head_dim).permute(0, 2, 3, 1, 4)
        else:
            raise ValueError(f"Unsupported attention dimension: {dim}")

        import habana_frameworks.torch.hpu as ht
        with ht.sdp_kernel(enable_recompute=False):  # (flash_attention_recompute and q_len == 1)):
            out = FusedSDPA.apply(
                query, key, value, None, p, is_causal, scale
            )
        htorch.core.mark_step()
        if dim == 4:
            # [bs, heads, seq_len, head_dim] -> [bs, seq_len, heads, head_dim]
            out = out.permute(0, 2, 1, 3).reshape(bs, seq_len_q, heads, head_dim)
        elif dim == 5:
            # [bs, heads, attn_groups, seq_len, head_dim] -> [bs, seq_len, heads, attn_groups, head_dim]
            out = out.permute(0, 3, 1, 2, 4).reshape(bs, seq_len_q, heads, attn_groups, head_dim)
    else:
        raise NotImplementedError(f'Only FusedSDPA causal or non-masked attention is supported.\nFusedSDPA support: {FusedSDPA is not None}\nis_causal: {is_causal}\nmask_present: {attn_bias is not None}')

    return out
@vllm.hpu.utils.with_mark_steps
def prompt_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    query_heads = query.size(1)
    kv_heads = key.size(1)
    if query_heads != kv_heads:
        query = query.unflatten(1, (kv_heads, -1))
        key = key.unflatten(1, (kv_heads, 1))
        value = value.unflatten(1, (kv_heads, 1))
        attn_bias = attn_bias.unsqueeze(2)
    attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
    if attn_bias is not None:
        attn_weights.add_(attn_bias)
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_weights = torch.matmul(attn_weights, value)
    if query_heads != kv_heads:
        attn_weights = attn_weights.flatten(1, 2)
    attn_weights = attn_weights.transpose(1, 2)
    return attn_weights
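A hedged usage sketch for the new prompt_attention helper, showing the expected (batch, seq, heads, head_dim) layout and how MQA/GQA inputs with fewer KV heads are accepted directly. The shapes and the bias below are made up, and actually running it requires the Habana software stack that vllm.hpu.utils and the with_mark_steps decorator depend on; on plain PyTorch the same math can be reproduced by copying the function body.

import torch

from vllm.hpu.xops import prompt_attention  # requires the Habana/HPU stack

# Assumed toy dimensions (not from the commit): 32 query heads share 8 KV heads.
bs, seq_len, num_heads, num_kv_heads, head_dim = 2, 128, 32, 8, 64

query = torch.randn(bs, seq_len, num_heads, head_dim)
key = torch.randn(bs, seq_len, num_kv_heads, head_dim)
value = torch.randn(bs, seq_len, num_kv_heads, head_dim)

# Additive causal bias broadcastable to (bs, heads, seq_len, seq_len).
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
attn_bias = torch.zeros(bs, 1, seq_len, seq_len).masked_fill_(causal, float("-inf"))

out = prompt_attention(query, key, value,
                       attn_bias=attn_bias,
                       scale=head_dim ** -0.5)
print(out.shape)  # torch.Size([2, 128, 32, 64])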
9 changes: 0 additions & 9 deletions vllm/model_executor/sampling_metadata.py
@@ -192,12 +192,6 @@ def _prepare_seq_groups(
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    # FIXME: On HPU prompts are right-padded. We need to take that into account
    # when updating model_output_idx
    if is_hpu() and len(seq_lens) > 0:
        assert seq_lens == query_lens, 'Prompt chunking is not yet supported on HPU!'
        max_seq_len = max(seq_lens)

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = list(seq_group_metadata.seq_data.keys())
        sampling_params = seq_group_metadata.sampling_params
@@ -225,12 +219,10 @@
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
            padding_len = 0 if not is_hpu() else max_seq_len - seq_len
        else:
            # Decode
            prompt_logprob_len = 0
            sample_len = len(seq_ids) if do_sample else 0
            padding_len = 0

        # Update indices to select from the model output.
        """
@@ -249,7 +241,6 @@
        selected_token_indices.extend(
            range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len
        model_output_idx += padding_len

        # We now find indices for logprob computation and sampling.
        """
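The net effect of this file's change is that selected_token_indices no longer carry a per-prompt padding offset; the index bookkeeping reduces to the plain accumulation shown in the hunk. A standalone illustration with made-up group sizes (not the committed code):

# Illustration only: how selected_token_indices accumulate without the
# HPU-specific padding_len offset.  Group sizes below are made up.
selected_token_indices: list[int] = []
model_output_idx = 0

# (prompt_logprob_len, sample_len) per sequence group
groups = [(0, 1),   # prompt without prompt_logprobs: sample the last token
          (7, 1)]   # prompt with prompt_logprobs over 7 positions

for prompt_logprob_len, sample_len in groups:
    selected_token_indices.extend(
        range(model_output_idx, model_output_idx + prompt_logprob_len))
    model_output_idx += prompt_logprob_len
    selected_token_indices.extend(
        range(model_output_idx, model_output_idx + sample_len))
    model_output_idx += sample_len

print(selected_token_indices)  # [0, 1, 2, 3, 4, 5, 6, 7, 8]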
(Diffs for the remaining 2 changed files were not loaded on this page.)
