Commit 2979fb2
[MTP] follow custom deepseek modeling changes to support graph mode
Signed-off-by: mengwei805 <mengwei25@huawei.com>
1 parent d785e78

File tree

3 files changed: +35 −3 lines changed

vllm_ascend/models/deepseek_mtp.py

Lines changed: 29 additions & 1 deletion
@@ -17,11 +17,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -34,6 +35,7 @@
                                              SharedHead)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
 
 from .deepseek_v2 import CustomDeepseekV2DecoderLayer
 
@@ -69,6 +71,8 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         previous_hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
@@ -88,6 +92,8 @@ def forward(
 
         hidden_states, residual = self.mtp_block(positions=positions,
                                                  hidden_states=hidden_states,
+                                                 kv_cache=kv_cache,
+                                                 attn_metadata=attn_metadata,
                                                  residual=None)
         hidden_states = residual + hidden_states
         return hidden_states
@@ -125,14 +131,20 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
+        kv_caches: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         previous_hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        step_kv_cache = kv_caches[
+            current_step_idx] if kv_caches is not None else None
         return self.layers_list[current_step_idx](
             input_ids,
             positions,
+            step_kv_cache,
+            attn_metadata,
             previous_hidden_states,
             inputs_embeds,
             current_step_idx,
@@ -170,3 +182,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                            prefix, "model"))
 
         self.sampler = get_sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        previous_hidden_states: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        spec_step_idx: int = 0,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, previous_hidden_states,
+                                   inputs_embeds, spec_step_idx)
+        return hidden_states
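
The net effect in this file is that kv_caches and attn_metadata are now threaded explicitly from the new top-level forward down to each MTP decoder layer, with one cache entry selected per step via spec_step_idx % self.num_mtp_layers. A minimal usage sketch of the new entry point; driver-side names here (mtp_model, input_ids, positions, attn_metadata, hidden_states, num_speculative_tokens) are illustrative assumptions, not part of the commit:

import torch

# Assumed prepared elsewhere: mtp_model, input_ids, positions,
# attn_metadata, hidden_states, num_speculative_tokens.
num_mtp_layers = 1  # DeepSeek-V3 checkpoints ship a single MTP module
kv_caches = [torch.empty(0) for _ in range(num_mtp_layers)]  # placeholders

for spec_step_idx in range(num_speculative_tokens):
    # The predictor indexes kv_caches with spec_step_idx % num_mtp_layers,
    # so passing the full per-layer list stays valid even when more
    # speculative steps are requested than there are MTP layers.
    hidden_states = mtp_model(
        input_ids=input_ids,
        positions=positions,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
        previous_hidden_states=hidden_states,
        spec_step_idx=spec_step_idx,
    )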

vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py

Lines changed: 5 additions & 1 deletion
@@ -61,15 +61,19 @@ def sampler_output(
         else:
             # Here we run multi-step directly, with every step prepared
             # on the CPU.
-            # TODO: Remove this branch once DraftModelRunner supports TP>1
+            # TODO Remove this branch once DraftModelRunner supports TP>1
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
+            if expanded_request.previous_hidden_states is not None:
+                self.worker.model_runner.return_hidden_states = True
             for _ in range(sample_len):
                 model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]
+                self._maybe_update_previous_hidden_states(model_output,
+                                                          expanded_request)
 
             self._append_new_tokens(model_output,
                                     expanded_request.seq_group_metadata_list,
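
The loop now asks the model runner to return hidden states and feeds them back between steps via self._maybe_update_previous_hidden_states, which this diff does not define. A plausible sketch of that helper, modeled on upstream vLLM's MultiStepWorker; treat the exact signature and the HiddenStates usage as assumptions:

from vllm.sequence import ExecuteModelRequest, HiddenStates


def _maybe_update_previous_hidden_states(
        model_output, expanded_request: ExecuteModelRequest) -> None:
    # Sketch only: feed the hidden states produced by the step that just
    # ran back into the request, so the next MTP step conditions on them.
    if expanded_request.previous_hidden_states is not None:
        expanded_request.previous_hidden_states = HiddenStates(
            model_output.hidden_states,
            expanded_request.seq_group_metadata_list)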

vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ def create_worker(
 
         proposer_worker = MultiStepWorker(**draft_worker_kwargs)
         if draft_model_config.hf_config.model_type == "deepseek_mtp":
-            num_spec_prefill_steps = num_speculative_tokens
+            num_spec_prefill_steps = draft_model_config.hf_config.n_predict
 
         proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
             proposer_worker, draft_tp, target_tp)
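
The fix separates two quantities that coincide only in simple setups: num_speculative_tokens is the user-requested number of draft tokens per step, while hf_config.n_predict is the number of MTP modules actually present in the checkpoint (typically 1 for DeepSeek-V3). Prefill needs one pass per physical MTP module, so n_predict is the right bound. A toy illustration with assumed values; only the source of num_spec_prefill_steps changes in this diff:

from transformers import PretrainedConfig

# Assumed toy config: DeepSeek MTP draft model with one MTP module.
hf_config = PretrainedConfig(model_type="deepseek_mtp", n_predict=1)
num_speculative_tokens = 3  # user-requested draft length per step

# Before: num_spec_prefill_steps = num_speculative_tokens  -> 3 passes
# After:  num_spec_prefill_steps = hf_config.n_predict     -> 1 pass
num_spec_prefill_steps = hf_config.n_predict
print(num_spec_prefill_steps)  # 1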
