Commit 45d59fd

[Refactor] Add build_dummy_metadata in attention backend and refactor common metadata
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 03be417 commit 45d59fd

5 files changed: +74 additions, −54 deletions
vllm_ascend/attention/attention_v1.py

Lines changed: 36 additions & 4 deletions
@@ -27,10 +27,11 @@
 from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import get_graph_params
 
@@ -163,13 +164,16 @@ def build(self,
         block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
             block_table[:num_reqs])
 
-        query_lens = self.runner.query_lens
+        query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        seq_lens_list = self.runner.seq_lens_list
+        # TODO: Refactor these two param to common metadata in runners,
+        # preparing for the hybrid KV groups feature
+        query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
+        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+
         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
-        query_start_loc = common_attn_metadata.query_start_loc
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -185,6 +189,34 @@ def build(self,
             enable_dbo_across_dp=enable_dbo_across_dp)
         return attn_metadata
 
+    def build_dummy_metadata(self, num_actual_tokens, num_reqs,
+                             num_scheduled_tokens, attn_state):
+        if attn_state == AscendAttentionState.DecodeOnly:
+            # NOTE: We only need to pay attention to seq_lens_list and block_table here
+            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
+                                                           num_reqs)
+
+            block_table = self.runner.input_batch.block_table[0].block_table
+            block_table[:num_reqs, 0] = torch.arange(1,
+                                                     num_reqs + 1,
+                                                     device=block_table.device,
+                                                     dtype=block_table.dtype)
+
+            attn_metadata = self.build(
+                num_reqs=num_reqs,
+                num_actual_tokens=num_actual_tokens,
+                max_query_len=num_scheduled_tokens.max(),
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+
 
 class AscendAttentionBackendImpl(AttentionImpl):

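Usage note: build_dummy_metadata is the new entry point the model runner calls during dummy/profiling runs (see the model_runner_v1.py hunk further down) instead of hand-building block tables and slot mappings itself. A minimal sketch of that call with illustrative values; the builder/runner wiring and the concrete numbers here are assumptions, and only the DecodeOnly state is supported so far:

    import numpy as np

    num_reqs = 4
    # One newly scheduled token per request, as in a pure decode step.
    num_scheduled_tokens = np.ones(num_reqs, dtype=np.int32)

    attn_metadata = runner.attn_metadata_builder.build_dummy_metadata(
        num_actual_tokens=num_reqs,                  # one token per request
        num_reqs=num_reqs,
        num_scheduled_tokens=num_scheduled_tokens,   # its .max() becomes max_query_len
        attn_state=AscendAttentionState.DecodeOnly,  # only supported state so far
    )
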
vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 14 deletions
@@ -16,6 +16,8 @@
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -28,20 +30,6 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 
-@dataclass
-class CommonAttentionMetadata:
-    """
-    Attention metadata attributes that can be shared by layers in different KV
-    cache groups and thus having different block table.
-    """
-
-    query_start_loc: torch.Tensor
-    """(batch_size + 1,), the start location of each request in query Tensor"""
-    seq_lens: torch.Tensor
-    """(batch_size,), the length of each request including both computed tokens
-    and newly scheduled tokens"""
-
-
 class AscendMLABackend(AttentionBackend):
 
     accept_output_buffer: bool = True

vllm_ascend/attention/utils.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+
+@dataclass
+class AscendCommonAttentionMetadata:
+    """
+    Attention metadata attributes that can be shared by layers in different KV
+    cache groups and thus having different block table.
+    """
+
+    query_start_loc: Optional[torch.Tensor] = None
+    """(batch_size + 1,), the start location of each request in query Tensor"""
+    seq_lens: Optional[torch.Tensor] = None
+    """(batch_size,), the length of each request including both computed tokens
+    and newly scheduled tokens"""
+    query_lens: Optional[torch.Tensor] = None
+    """(batch_size,), the length of each request including only the newly
+    scheduled tokens"""
+    seq_lens_list: Optional[list] = None
+    """(num_input_tokens,), note that this is specifically for FIA kernel"""

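Since every field on the relocated dataclass is Optional and defaults to None, build() can fall back to the runner's own query_lens / seq_lens_list when a caller leaves them unset. A small illustrative sketch (the values are made up) of the two construction styles this commit uses:

    import torch
    from vllm_ascend.attention.utils import AscendCommonAttentionMetadata

    num_reqs = 4

    # Regular build path: per-batch tensors prepared by the model runner.
    meta = AscendCommonAttentionMetadata(
        query_start_loc=torch.arange(num_reqs + 1, dtype=torch.int32),
        seq_lens=torch.full((num_reqs,), 2, dtype=torch.int32),
    )

    # Dummy decode-only path (what build_dummy_metadata constructs): only
    # seq_lens_list is set; the fields left as None make build() fall back
    # to the runner's own query_lens / seq_lens_list.
    dummy_meta = AscendCommonAttentionMetadata(seq_lens_list=[2] * num_reqs)
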
vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 35 deletions
@@ -77,7 +77,8 @@
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -253,6 +254,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.slot_mapping = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device=self.device)
+        self.query_lens = torch.zeros(self.max_num_reqs,
+                                      dtype=torch.int32,
+                                      device=self.device)
         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
 
@@ -1528,6 +1532,7 @@ def _dummy_run(
         skip_attn: bool = True,
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
+        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
     ) -> torch.Tensor:
         if self.torchair_graph_enabled and not with_prefill:
             num_tokens = self.select_torchair_padded_batch_size(num_tokens)
@@ -1558,43 +1563,13 @@
         elif skip_attn:
             attn_metadata = None
         else:
-            query_start_loc = self.query_start_loc[:num_reqs + 1]
-            query_start_loc[:] = torch.arange(
-                query_start_loc.numel(),
-                device=query_start_loc.device,
-                dtype=query_start_loc.dtype,
-            )
-            seq_lens = self.seq_lens_np[:num_reqs]
-            seq_lens[:] = seq_lens + 2
-            self.seq_lens_list = self.seq_lens_np.tolist()[:num_tokens]
-
-            common_attn_metadata = CommonAttentionMetadata(
-                query_start_loc=query_start_loc, seq_lens=seq_lens)
-
-            self.query_lens = torch.from_numpy(num_scheduled_tokens)
-
-            block_table = self.input_batch.block_table[0].block_table
-            block_table[:num_reqs, 0] = torch.arange(1,
-                                                     num_reqs + 1,
-                                                     device=block_table.device,
-                                                     dtype=block_table.dtype)
-
-            self.slot_mapping[:num_tokens] = torch.arange(
-                1,
-                num_tokens + 1,
-                device=self.slot_mapping.device,
-                dtype=self.slot_mapping.dtype) * self.block_size + 1
-
-            attn_metadata = self.attn_metadata_builder.build(
-                num_reqs=num_reqs,
+            attn_metadata = self.attn_metadata_builder.build_dummy_metadata(
                 num_actual_tokens=num_tokens,
-                max_query_len=num_tokens,
-                common_prefix_len=0,
-                common_attn_metadata=common_attn_metadata,
+                num_reqs=num_reqs,
+                num_scheduled_tokens=num_scheduled_tokens,
+                attn_state=attn_state,
             )
 
-            attn_metadata.attn_state = AscendAttentionState.DecodeOnly
-
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
@@ -1977,6 +1952,7 @@ def capture_model(self) -> None:
         # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
         with graph_capture(device=self.device):
             skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
+            # TODO: Make sure passing attn_state to _dummy_run in the future
            for num_tokens in reversed(self.aclgraph_batch_sizes):
                for _ in range(self.vllm_config.compilation_config.
                               cudagraph_num_of_warmups):

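The capture_model TODO above hints that graph capture should eventually pass an explicit attn_state into _dummy_run instead of relying on the new DecodeOnly default. A hypothetical sketch of what that call could look like, assuming num_tokens is _dummy_run's first positional argument as the surrounding code suggests:

    # Hypothetical future call site inside capture_model's warmup loop.
    runner._dummy_run(
        num_tokens,
        skip_attn=False,  # actually build (dummy) attention metadata
        attn_state=AscendAttentionState.DecodeOnly,
    )
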
vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
-from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
