
Commit 8bae971

adapt mtp with graph mode in v1
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 908a851 · commit 8bae971

3 files changed: +57 -13 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 41 additions & 8 deletions

@@ -8,6 +8,7 @@
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 
@@ -83,6 +84,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    attn_mask: torch.Tensor
 
 
 @dataclass
@@ -170,11 +172,13 @@ def reorder_batch(self, input_batch: "InputBatch",
 
         for i, req_id in enumerate(input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             # for now treat 1 scheduled token as "decode" even if its not,
             # we should update this to something like < 8 in the future but
             # currently the TritonMLA._forward_decode only supports
             # num_tokens = 1
-            if num_tokens == 1:
+            if num_tokens - num_spec_tokens == 1:
                 decodes.append(i)
                 num_decode_tokens += num_tokens
             else:
@@ -269,7 +273,8 @@ def build_dummy(self, num_reqs: int,
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
-            max_seq_lens=1)
+            max_seq_lens=1,
+            attn_mask=self.runner.spec_attn_mask)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -317,7 +322,7 @@ def build(
             seq_lens = seq_lens_cpu
             max_query_len = query_lens.max().item()
             max_seq_lens = seq_lens.max().item()
-            query_start_loc = None
+            query_start_loc = common_attn_metadata.query_start_loc
 
         prefill_metadata = None
         if self._num_prefills > 0:
@@ -382,7 +387,8 @@ def build(
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
-                max_seq_lens=max_seq_lens)
+                max_seq_lens=max_seq_lens,
+                attn_mask=self.runner.spec_attn_mask)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -445,6 +451,17 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # Adapt torch air graph mode with spec decoding.
+        speculative_config = get_current_vllm_config().speculative_config
+        self.fia_sparse_mode = 0
+        self.use_spec_decode = False
+        # We need to set the sparse_mode of fused_infer_attention op to 3
+        # in spec decoding scenario in order to pass in attention mask.
+        if speculative_config is not None:
+            self.fia_sparse_mode = 3
+            self.use_spec_decode = True
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0
 
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
@@ -646,9 +663,24 @@ def _forward_decode(
                                   dtype=q.dtype,
                                   device=q.device)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if self.use_spec_decode:
+                assert num_tokens % self.spec_token_num == 0
+                q_nope = (q_nope.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+                q_pe = (q_pe.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -666,7 +698,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=attn_metadata.decode.attn_mask,  # type:ignore
+                sparse_mode=self.fia_sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,
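
The core of this change is the query reshape in _forward_decode: with MTP speculative decoding, each decode request carries one real token plus spec_token_num draft tokens, so the flat token dimension is regrouped into TorchAir's [bs, num_heads_per_rank, q_seq_len, dim] (BNSD) layout, and the fused-infer-attention op is switched to sparse_mode=3 so the per-request attention mask can be passed in. A minimal standalone sketch of that reshape follows; it is not part of the patch, and all sizes (num_heads, head_dim, spec_token_num, batch) are assumed for illustration.

    import torch

    # Illustrative sizes only; the real values come from the model config.
    num_heads = 16
    head_dim = 512
    spec_token_num = 1                          # draft tokens per request
    batch = 4                                   # decode requests in this step
    num_tokens = batch * (spec_token_num + 1)   # one real token + drafts each

    q_nope = torch.randn(num_tokens, num_heads, head_dim)

    # Flat [num_tokens, num_heads, dim] -> [bs, num_heads, q_seq_len, dim]
    # with q_seq_len = spec_token_num + 1, matching input_layout="BNSD".
    q_bnsd = (q_nope.view(num_tokens // (spec_token_num + 1),
                          spec_token_num + 1,
                          num_heads,
                          -1).transpose(1, 2).contiguous())
    print(q_bnsd.shape)  # torch.Size([4, 16, 2, 512])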

vllm_ascend/worker/model_runner_v1.py

Lines changed: 12 additions & 4 deletions

@@ -196,8 +196,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
 
         # Set up speculative decoding.
         self.use_spec_decode = False
+        self.spec_attn_mask = None
         if self.speculative_config:
             self.use_spec_decode = True
+            self.spec_attn_mask = torch.triu(torch.ones(2048,
+                                                        2048,
+                                                        dtype=torch.bool),
+                                             diagonal=1).to("npu")
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
@@ -564,10 +569,13 @@ def _process_reqs(
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
+        num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
         max_num_scheduled_tokens = 0
         for i, req_id in enumerate(self.input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens[i] = num_tokens
+            num_valid_tokens[i] = num_tokens - \
+                len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
 
@@ -615,7 +623,7 @@ def _process_reqs(
         if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
-        elif np.all(num_scheduled_tokens == 1):
+        elif np.all(num_valid_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
         # splitfuse
         elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
@@ -657,14 +665,14 @@ def _process_reqs(
         # Add graph_pad_size here
         if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled
                                            and not with_prefill):
-            batch_size = len(seq_lens)
             if self.dp_size > 1:
                 padded_batch_size = self.select_torchair_padded_batch_size(
                     max_num_tokens)
             else:
                 padded_batch_size = self.select_torchair_padded_batch_size(
-                    batch_size)
-            graph_pad_size = padded_batch_size - batch_size
+                    total_num_scheduled_tokens)
+            graph_pad_size = padded_batch_size - total_num_scheduled_tokens
+
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
 
         if self.vllm_config.model_config.use_mla:
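
Two runner-side pieces support this: an upper-triangular boolean mask built once in __init__ and handed to the decode metadata as spec_attn_mask, and a num_valid_tokens count (scheduled tokens minus draft tokens) so a multi-token speculative request still lands in AscendAttentionState.DecodeOnly; graph padding is then computed from total_num_scheduled_tokens instead of the request count, since each request can now contribute more than one token. The sketch below mirrors both ideas; it keeps the mask on CPU so it runs without an NPU, and the request counts are made up.

    import numpy as np
    import torch

    # Same construction as the runner's spec_attn_mask, minus the .to("npu")
    # so the sketch runs anywhere: True marks positions to be masked out.
    spec_attn_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool),
                                diagonal=1)
    print(spec_attn_mask[:3, :3])
    # tensor([[False,  True,  True],
    #         [False, False,  True],
    #         [False, False, False]])

    # DecodeOnly detection: subtract draft tokens before the "== 1" check.
    # Counts below are assumed: 3 requests, each with 1 real + 1 draft token.
    num_scheduled_tokens = np.array([2, 2, 2], dtype=np.int32)
    num_draft_tokens = np.array([1, 1, 1], dtype=np.int32)
    num_valid_tokens = num_scheduled_tokens - num_draft_tokens
    print(bool(np.all(num_valid_tokens == 1)))  # True -> DecodeOnly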

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 4 additions & 1 deletion

@@ -4,7 +4,8 @@
                          set_current_vllm_config)
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.model_loader.utils import (
+    process_weights_after_loading, set_default_torch_dtype)
 from vllm.v1.sample.metadata import SamplingMetadata
 
 from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
@@ -199,6 +200,8 @@ def load_model(self) -> None:
             loader.get_all_weights(
                 self.vllm_config.speculative_config.draft_model_config,
                 self.model))
+        process_weights_after_loading(self.model, draft_model_config,
+                                      target_device)
 
 
 # TODO Using torch instead of triton may result in poor performance
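
The proposer change is about ordering: the draft model's weights are first loaded via loader.get_all_weights(...) and only then finalized with vLLM's process_weights_after_loading on the target device. A toy sketch of that load-then-finalize pattern follows; finalize_weights is a stand-in written for this example and is not vLLM's implementation.

    import torch
    import torch.nn as nn

    def finalize_weights(model: nn.Module, device: str,
                         dtype: torch.dtype) -> None:
        # Stand-in for the post-load step: device placement and dtype
        # conversion happen only after every weight has been loaded.
        model.to(device=device, dtype=dtype)

    draft = nn.Linear(8, 8)
    checkpoint = {k: v.clone() for k, v in draft.state_dict().items()}
    draft.load_state_dict(checkpoint)               # step 1: load all weights
    finalize_weights(draft, "cpu", torch.float16)   # step 2: finalize in place
    print(next(draft.parameters()).dtype)           # torch.float16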
