Commit 03fbc3c

Author: hw_whx

add v0 style schedule into v1 engine

Signed-off-by: hw_whx <wanghexiang7@huawei.com>

1 parent 0be96f6

File tree: 7 files changed (+648, -68 lines changed)

vllm_ascend/attention/attention.py

Lines changed: 46 additions & 7 deletions
@@ -43,7 +43,7 @@
     ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
 
 
-def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
+def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
     # Construct lower triangle matrix.
     mask_flag = torch.tril(
         torch.ones((max_seq_len, max_seq_len),
@@ -52,10 +52,11 @@ def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
     mask_flag = ~mask_flag
     # Currently for fp16 dtype, the mask value should be set to -inf.
     # TODO: Eliminate this part in the future.
-    if dtype == torch.float16:
-        mask_value = torch.finfo(torch.float32).min
-    else:
-        mask_value = 1
+    if mask_value is None:
+        if dtype == torch.float16:
+            mask_value = torch.finfo(torch.float32).min
+        else:
+            mask_value = 1
     attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
                                   mask_flag, mask_value).to(dtype)
     return attn_mask
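
For reference, a minimal standalone sketch of how the patched generate_attn_mask behaves: the causal (lower-triangular) mask is built as before, but a caller-supplied mask_value now overrides the dtype-dependent default. This is a simplified copy for illustration, not the vllm-ascend function itself.

import torch

def generate_attn_mask_sketch(max_seq_len: int, dtype=torch.float16, mask_value=None):
    # Invert a lower-triangular matrix: True marks future positions to mask.
    mask_flag = ~torch.tril(torch.ones((max_seq_len, max_seq_len))).bool()
    if mask_value is None:
        # Default as in the hunk above: an effectively -inf value for fp16, 1 otherwise.
        mask_value = torch.finfo(torch.float32).min if dtype == torch.float16 else 1
    return torch.zeros((max_seq_len, max_seq_len)).masked_fill(mask_flag, mask_value).to(dtype)

default_mask = generate_attn_mask_sketch(4)                       # dtype-dependent default
override_mask = generate_attn_mask_sketch(4, mask_value=-10000)   # explicit sentinel value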
@@ -66,12 +67,14 @@ class AttentionMaskBuilder:
     def __init__(self, attn_mask: torch.Tensor):
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
+        self.splitfuse_mask_value = -10000
 
     @classmethod
     def initialize_from_len(cls,
                             max_seq_len: int,
-                            dtype: torch.dtype = torch.float16):
-        return cls(generate_attn_mask(max_seq_len, dtype))
+                            dtype: torch.dtype = torch.float16,
+                            mask_value: Optional[int] = None):
+        return cls(generate_attn_mask(max_seq_len, dtype, mask_value))
 
     def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
                           device: torch.device):
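
A hypothetical usage of the extended builder, showing the new mask_value argument flowing through initialize_from_len into generate_attn_mask (the module path follows the file header above; argument values are illustrative only):

import torch
from vllm_ascend.attention.attention import AttentionMaskBuilder

# Default behaviour: dtype-dependent mask value, as before this commit.
builder = AttentionMaskBuilder.initialize_from_len(1024, torch.float16)

# New in this commit: an explicit mask value, e.g. the same -10000 sentinel the
# builder stores as self.splitfuse_mask_value for split-fuse masks.
sf_builder = AttentionMaskBuilder.initialize_from_len(1024, torch.float16, mask_value=-10000)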
@@ -96,6 +99,42 @@ def get_decode_attn_mask(
         self.update_attn_cache(max_s, dtype, device)
         return (self.attn_mask_cache.index_select(
             0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
+
+    def get_splitfuse_attn_mask(
+        self,
+        seq_lens,
+        query_lens,
+        position,
+        dtype,
+        device,
+    ) -> torch.Tensor:
+        max_seq_len = max(seq_lens, default=0)
+        if max_seq_len <= self._seq_len_cached:
+            self.update_attn_cache(max_seq_len, dtype, device)
+            return torch.index_select(self.attn_mask_cache,
+                                      dim=0,
+                                      index=position)[:, :max_seq_len]
+        total_q_len = sum(query_lens)
+        attn_mask = torch.zeros((total_q_len, max_seq_len),
+                                dtype=self.vllm_config.model_config.dtype,
+                                device="cpu")
+
+        current_row = 0
+        for i in range(len(query_lens)):
+            seq_len = seq_lens[i]
+            q_len = query_lens[i]
+            context_len = seq_len - q_len
+
+            assert context_len >= 0
+            attn_mask[current_row:current_row + q_len,
+                      context_len:] = self.splitfuse_mask_value
+            right_tensor = attn_mask[current_row:current_row + q_len,
+                                     context_len:seq_len]
+            right_tensor.masked_fill_(
+                right_tensor.tril() == self.splitfuse_mask_value, 0)
+            current_row += q_len
+
+        return attn_mask.to(self.device, non_blocking=True)
 
 
 class AscendAttentionBackend(AttentionBackend):
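
To show the mask layout get_splitfuse_attn_mask builds when it falls back to per-step construction, here is a standalone sketch. It is simplified: it takes dtype directly instead of reading the builder's attributes, defaults to float32, and stays on the CPU. The layout is one row per scheduled query token and one column per key position; 0 means visible, -10000 masks future tokens.

import torch

def splitfuse_mask(seq_lens, query_lens, dtype=torch.float32, mask_value=-10000):
    max_seq_len = max(seq_lens, default=0)
    attn_mask = torch.zeros((sum(query_lens), max_seq_len), dtype=dtype)
    row = 0
    for seq_len, q_len in zip(seq_lens, query_lens):
        context_len = seq_len - q_len          # tokens already in the KV cache
        # Mask everything past the cached context, then re-open the causal
        # lower triangle for this request's own new query tokens.
        attn_mask[row:row + q_len, context_len:] = mask_value
        span = attn_mask[row:row + q_len, context_len:seq_len]
        span.masked_fill_(span.tril() == mask_value, 0)
        row += q_len
    return attn_mask

# Two requests: 3 cached + 2 new tokens, and 0 cached + 2 new tokens.
print(splitfuse_mask(seq_lens=[5, 2], query_lens=[2, 2]))

With seq_lens=[5, 2] and query_lens=[2, 2], the first request's two rows keep its three cached tokens visible and apply a causal pattern to its two new tokens, while the second request gets a plain 2x2 causal block and every column beyond its own length stays masked.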

vllm_ascend/attention/attention_v1.py

Lines changed: 76 additions & 18 deletions
@@ -18,6 +18,8 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type
 
+import numpy as np
+
 import torch
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -50,7 +52,7 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return (2, num_blocks, block_size, num_kv_heads * head_size)
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
     def swap_blocks(
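
The KV-cache shape reported by the backend now keeps num_kv_heads and head_size as separate dimensions, which is what lets forward() further down drop the per-step view() calls on the cache. A small illustrative allocation (sizes are made up, not taken from the commit):

import torch

num_blocks, block_size, num_kv_heads, head_size = 128, 16, 8, 64
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size,
                       dtype=torch.float16)
key_cache, value_cache = kv_cache[0], kv_cache[1]   # already (blocks, block, heads, head_size)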
@@ -83,6 +85,21 @@ def copy_blocks(
         value_caches[dst_indices] = value_caches[src_indices]
 
 
+# class AscendAttentionV0StyleBackend(AscendAttentionBackend):
+#     @staticmethod
+#     def get_impl_cls() -> Type["AscendAttentionBackendV0StyleImpl"]:
+#         return AscendAttentionBackendV0StyleImpl
+
+#     @staticmethod
+#     def get_kv_cache_shape(
+#         num_blocks: int,
+#         block_size: int,
+#         num_kv_heads: int,
+#         head_size: int,
+#     ) -> Tuple[int, ...]:
+#         return (2, num_blocks, block_size, num_kv_heads, head_size)
+
+
 @dataclass
 class AscendMetadata:
     # (batch_size, max_blocks_per_seq).
@@ -104,6 +121,11 @@ class AscendMetadata:
     # FlashAttention has better performance than PagedAttention,
     # but it does not support decode requests.
     is_only_prefill: bool = False
+    # These two parameters indicate the number of prefill and decode requests
+    # scheduled in this step. They are used by AscendAttentionBackendPrefillFirstImpl
+    # to determine whether to perform prefill or decode under the prefill-first
+    # scheduling strategy.
+    num_prefills: int = 0
+    num_decodes: int = 0
 
     attn_mask: Optional[torch.Tensor] = None
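
As a rough illustration of how a prefill-first, v0-style scheduler could populate these two counters (the real scheduler is added elsewhere in this commit; the helper below is hypothetical), a step is either all prefill or all decode:

def count_phases(waiting_prefills: int, running_decodes: int) -> tuple:
    # Prefill-first: drain waiting prompts before resuming decode steps.
    if waiting_prefills > 0:
        return waiting_prefills, 0
    return 0, running_decodes

num_prefills, num_decodes = count_phases(waiting_prefills=2, running_decodes=0)  # -> (2, 0)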

@@ -140,6 +162,8 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.seq_len_cpu_tensor = None
+        self.key_cache = None
+        self.value_cache = None
 
     def forward(
         self,
@@ -190,30 +214,64 @@ def forward(
         # TODO: Remove this contiguous in the future.
         value = value.contiguous()
 
+        if kv_cache.numel() > 0:
+            if self.key_cache is None:
+                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
+            slots = attn_metadata.slot_mapping
+            torch_npu._npu_reshape_and_cache(key=key,
+                                             value=value,
+                                             key_cache=self.key_cache,
+                                             value_cache=self.value_cache,
+                                             slot_indices=slots)
+
         if hasattr(layer, 'quant_method'):
             # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
             pass
+        # V0-Style scheduler situation.
+        elif attn_metadata.num_prefills is not None:
+            if attn_metadata.num_prefills > 0:
+                assert attn_metadata is not None
+                assert attn_metadata.attn_mask is not None
+                mask = attn_metadata.attn_mask
+                self.seq_lens_tensor_cpu = torch.from_numpy(
+                    np.array(attn_metadata.seq_lens).astype(np.int32))
+                torch_npu._npu_flash_attention(
+                    query=query,
+                    key=key,
+                    value=value,
+                    mask=mask,
+                    seq_len=self.seq_lens_tensor_cpu,
+                    scale_value=self.scale,
+                    num_heads=self.num_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    out=output)
+            elif attn_metadata.num_decodes > 0:
+                # assert self.key_cache is not None
+                self.seq_lens_tensor_cpu = torch.from_numpy(
+                    np.array(attn_metadata.context_lens).astype(np.int32))
+                block_tables = attn_metadata.block_tables
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=block_tables,
+                    context_lens=self.seq_lens_tensor_cpu,
+                    out=output)
+            else:
+                raise ValueError(
+                    "At least one of num_prefills and num_decodes should be "
+                    "greater than 0 in the v0-style scheduling situation.")
+        # Normal V1 situation.
         else:
-            if kv_cache.numel() > 0:
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                num_blocks, block_size, _ = key_cache.shape
-                key_cache = key_cache.view(num_blocks, block_size,
-                                           self.num_kv_heads, self.head_size)
-                value_cache = value_cache.view(num_blocks, block_size,
-                                               self.num_kv_heads,
-                                               self.head_size)
-                slots = attn_metadata.slot_mapping
-                torch_npu._npu_reshape_and_cache(key=key,
-                                                 value=value,
-                                                 key_cache=key_cache,
-                                                 value_cache=value_cache,
-                                                 slot_indices=slots)
-
             # use paged attention
             torch_npu._npu_paged_attention_splitfuse(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
                 mask=attn_metadata.attn_mask,
                 block_table=attn_metadata.block_tables,
                 seq_len=attn_metadata.seq_lens,
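
Condensed, the new v0-style branch keys the kernel choice off the two counters: prefill steps attend over this step's raw key/value, driven by seq_lens and the attention mask, while decode steps attend over the block-organised KV cache, driven by block_tables and context_lens. A sketch of that dispatch with the NPU kernels stubbed out as callables (names here are hypothetical, not part of the commit):

import numpy as np
import torch

def v0_style_dispatch(attn_metadata, run_flash_attention, run_paged_attention):
    if attn_metadata.num_prefills > 0:
        # Prefill: per-request prompt lengths plus the causal mask.
        seq_lens_cpu = torch.from_numpy(np.array(attn_metadata.seq_lens).astype(np.int32))
        return run_flash_attention(seq_lens_cpu, attn_metadata.attn_mask)
    if attn_metadata.num_decodes > 0:
        # Decode: per-request cached-context lengths plus the block tables.
        context_lens_cpu = torch.from_numpy(np.array(attn_metadata.context_lens).astype(np.int32))
        return run_paged_attention(context_lens_cpu, attn_metadata.block_tables)
    raise ValueError("a v0-style step must schedule at least one prefill or decode")

Because a v0-style step is either pure prefill or pure decode, the two branches never mix within one call; the split-fuse path in the final else branch remains the normal V1 behaviour.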

vllm_ascend/core/__init__.py

Whitespace-only changes.
