4 changes: 3 additions & 1 deletion vllm_ascend/ascend_forward_context.py
@@ -71,7 +71,8 @@ def set_ascend_forward_context(
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None,
weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
is_mtp_model=False):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
@@ -154,6 +155,7 @@ def set_ascend_forward_context(
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
forward_context.model_instance = model_instance
forward_context.weight_prefetch_method = weight_prefetch_method
forward_context.is_mtp_model = is_mtp_model

# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
# It will be improved later by implementing operator fusion through the FX graph.
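The new `is_mtp_model` flag simply rides along on the forward context so that downstream attention and graph-update code can tell whether the current pass belongs to the MTP draft model. A minimal sketch of how a caller might set it; the positional arguments, variable names, and the surrounding runner code are assumptions for illustration, not part of this diff:

    from vllm_ascend.ascend_forward_context import set_ascend_forward_context

    # Hypothetical call site: the MTP proposer enters the forward context with
    # is_mtp_model=True so that _forward_decode() in mla_v1.py and
    # update_mla_attn_params() pick the MTP-specific graph params instead of
    # the main model's registry.
    with set_ascend_forward_context(attn_metadata,
                                    vllm_config,
                                    num_tokens=num_input_tokens,
                                    is_mtp_model=True):
        hidden_states = mtp_model(input_ids, positions, previous_hidden_states)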
6 changes: 5 additions & 1 deletion vllm_ascend/attention/mla_v1.py
@@ -34,6 +34,7 @@
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
get_mtp_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
@@ -1174,8 +1175,11 @@ def _forward_decode(
"actual_seq_lengths": actual_seq_lengths,
"actual_seq_lengths_kv": decode_meta.seq_lens_list,
}
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
if forward_context.is_mtp_model:
graph_params = get_mtp_graph_params()
else:
graph_params = get_graph_params()
if forward_context.capturing:
stream = torch_npu.npu.current_stream()

45 changes: 43 additions & 2 deletions vllm_ascend/compilation/acl_graph.py
@@ -233,7 +233,10 @@ def update_attn_params(update_stream, forward_context, runtime_shape):

def update_mla_attn_params(update_stream, forward_context, runtime_shape,
speculative_config):
graph_params = get_graph_params()
if forward_context.is_mtp_model:
graph_params = get_mtp_graph_params()
else:
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
@@ -249,7 +252,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
softmax_lse) = param
seq_lens_list = forward_context.attn_metadata[
key].decode.seq_lens_list
if speculative_config and speculative_config.method == "deepseek_mtp":
if speculative_config and speculative_config.method == "deepseek_mtp" \
and not forward_context.is_mtp_model:
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1
@@ -434,3 +438,40 @@ def update_graph_params_workspaces(num_tokens: int, workspace: Any):

def get_graph_params():
return _graph_params


@dataclass
class MTPGraphParams:
events: dict[int, list[torch.npu.ExternalEvent]]
workspaces: dict[int, torch.Tensor]
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
attn_params: dict[int, list[tuple]]


_mtp_graph_params: Optional[MTPGraphParams] = None


def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
global _mtp_graph_params
if _mtp_graph_params is not None:
raise ValueError("MTP graph parameters have already been set!")
_mtp_graph_params = MTPGraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
)


def update_mtp_graph_params_workspaces(num_tokens: int, workspace: Any):
global _mtp_graph_params
if _mtp_graph_params is not None:
_mtp_graph_params.workspaces[num_tokens] = workspace


def get_mtp_graph_params():
return _mtp_graph_params
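The new `MTPGraphParams` registry mirrors the existing `GraphParams` plumbing but keeps the MTP draft model's per-capture-size events, workspaces, handles, and attention params separate from the main model's. A rough lifecycle sketch, assuming the capture sizes are known at worker init time; the sizes and call ordering shown here are illustrative, not taken from this diff:

    from vllm_ascend.compilation.acl_graph import (
        get_mtp_graph_params, set_mtp_graph_params,
        update_mtp_graph_params_workspaces)

    # 1) Once, before capturing ACL graphs for the MTP model.
    set_mtp_graph_params({1, 8, 16})

    # 2) During capture, attention ops append their per-layer handles/params
    #    and record the workspace allocated for a given token count.
    mtp_params = get_mtp_graph_params()
    assert 8 in mtp_params.attn_params          # one bucket per capture size
    update_mtp_graph_params_workspaces(8, workspace=None)  # real code passes a tensor

    # 3) At replay time, update_mla_attn_params() reads this same registry
    #    whenever forward_context.is_mtp_model is True.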
173 changes: 173 additions & 0 deletions vllm_ascend/patch/worker/patch_deepseek_mtp.py
@@ -10,6 +10,22 @@
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
from vllm.model_executor.models.utils import maybe_prefix

from vllm_ascend.utils import vllm_version_is

if vllm_version_is("0.11.0"):
from collections.abc import Iterable

from vllm.compilation.decorators import support_torch_compile
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMTP, DeepSeekMultiTokenPredictor)
from vllm.model_executor.models.deepseek_v2 import \
get_spec_layer_idx_from_weight_name
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.sequence import IntermediateTensors


class SharedHead(nn.Module):

@@ -51,4 +67,161 @@
topk_indices_buffer)


def predictor_forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,

Check failure on line 75 in vllm_ascend/patch/worker/patch_deepseek_mtp.py (GitHub Actions / lint / pre-commit): X | Y syntax for unions requires Python 3.10 [syntax]
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)

hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
)

hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
residual=None)
hidden_states = residual + hidden_states
return hidden_states


@support_torch_compile
class CustomDeepSeekMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,

Check failure on line 112 in vllm_ascend/patch/worker/patch_deepseek_mtp.py (GitHub Actions / lint / pre-commit): X | Y syntax for unions requires Python 3.10 [syntax]
inputs_embeds: torch.Tensor | None = None,

Check failure on line 113 in vllm_ascend/patch/worker/patch_deepseek_mtp.py (GitHub Actions / lint / pre-commit): X | Y syntax for unions requires Python 3.10 [syntax]
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, hidden_states,
inputs_embeds, spec_step_idx)
return hidden_states

def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor | None:

Check failure on line 124 in vllm_ascend/patch/worker/patch_deepseek_mtp.py (GitHub Actions / lint / pre-commit): X | Y syntax for unions requires Python 3.10 [syntax]
return self.model.compute_logits(hidden_states, spec_step_idx)

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]

expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
)

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name_mapped = name.replace(weight_name, param_name)

# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
if (
param_name == "fused_qkv_a_proj"
) and name_mapped not in params_dict:
continue
else:
name = name_mapped

# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

# According to the DeepSeek-V3 Technical Report, MTP modules
# share the embedding layer. We only load the first weights.
if (
spec_layer != self.model.mtp_start_layer_idx
and ".layers" not in name):
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
DeepSeekMultiTokenPredictorLayer.forward = predictor_forward
if vllm_version_is("0.11.0"):
DeepSeekMTP = CustomDeepSeekMTP
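For reference, the `torch.where` call in `predictor_forward` above zeroes the token embedding wherever `positions == 0`, which the in-line comment notes is not needed by MTP. A small standalone illustration of that expression; the toy shapes and values are made up for the example:

    import torch

    # Toy batch: 4 tokens with hidden size 3; position 0 marks a sequence start.
    positions = torch.tensor([0, 1, 2, 0])
    inputs_embeds = torch.ones(4, 3)

    # Same expression as in predictor_forward: embeddings at position 0 become
    # zero, every other position is left untouched.
    masked = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
    # masked ->
    # tensor([[0., 0., 0.],
    #         [1., 1., 1.],
    #         [1., 1., 1.],
    #         [0., 0., 0.]])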