From 19578225da13a4de4cfb60ce8a8812d6b5fb17be Mon Sep 17 00:00:00 2001 From: zhuohuan Date: Sat, 24 May 2025 00:05:08 +0800 Subject: [PATCH 01/11] [feat]: support dbo for deepseek Signed-off-by: zhuohuan --- vllm_ascend/attention/mla_v1.py | 51 +++- vllm_ascend/models/deepseek_v2.py | 340 +++++++++++++++++++++++++- vllm_ascend/multistream/__init__.py | 0 vllm_ascend/multistream/base.py | 27 ++ vllm_ascend/multistream/context.py | 62 +++++ vllm_ascend/multistream/decorator.py | 21 ++ vllm_ascend/multistream/layers.py | 42 ++++ vllm_ascend/multistream/metadata.py | 160 ++++++++++++ vllm_ascend/multistream/ms_split.py | 173 +++++++++++++ vllm_ascend/ops/fused_moe.py | 32 +++ vllm_ascend/worker/model_runner_v1.py | 1 + 11 files changed, 901 insertions(+), 8 deletions(-) create mode 100644 vllm_ascend/multistream/__init__.py create mode 100644 vllm_ascend/multistream/base.py create mode 100644 vllm_ascend/multistream/context.py create mode 100644 vllm_ascend/multistream/decorator.py create mode 100644 vllm_ascend/multistream/layers.py create mode 100644 vllm_ascend/multistream/metadata.py create mode 100644 vllm_ascend/multistream/ms_split.py diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index ae3dd6205b..cd4b746992 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -15,6 +15,9 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn + if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch @@ -117,6 +120,7 @@ class AscendMLAMetadata: with_prefill_across_dp: bool = False + query_lens: list[int] = None # The dimension of the attention heads head_dim: Optional[int] = None attn_mask: torch.Tensor = None @@ -135,6 +139,17 @@ def __post_init__(self): # f"Only {supported_head_sizes} are supported for head_dim,", # f"received {self.head_dim}.") + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendMLAMetadata"]: + """Split metadata for multi-stream with AscendMLAMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendMLAMetadata, + ) + M = TypeVar("M", bound=AscendMLAMetadata) @@ -386,6 +401,8 @@ def build( return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, + query_lens=query_lens.tolist(), + seq_lens=seq_lens, slot_mapping=slot_mapping, head_dim=self.runner.model_config.get_head_size(), num_decodes=self._num_decodes, @@ -397,7 +414,6 @@ def build( decode=decode_metadata, query_start_loc=query_start_loc, block_tables=block_table, - seq_lens=seq_lens, with_prefill_across_dp=with_prefill_across_dp, ) @@ -811,16 +827,37 @@ def forward( key_cache=kv_cache, slot_indices=attn_metadata.slot_mapping.flatten()) if has_prefill: - output[num_decode_tokens:] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + # FIX: aicore move/copy should be also placed on the comm stream in dbo, + # otherwise it may affect the accuracy or disturb the overlap of next stage + # TODO: use an elegant way here to avoid it + output_prefill = self._forward_prefill(prefill_q, + prefill_k_c_normed, + prefill_k_pe, kv_cache, + attn_metadata) + from 
vllm.multistream.context import get_multistream_comm_context + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + output[num_decode_tokens:] = output_prefill + current_ms_metadata.after_comm_event.record() + else: + output[num_decode_tokens:] = output_prefill if has_decode: if self.running_in_graph: return self._forward_decode(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, kv_cache, attn_metadata) else: - output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, - kv_cache, attn_metadata) + from vllm.multistream.context import get_multistream_comm_context + current_ms_metadata = get_multistream_comm_context() + output_decode = self._forward_decode(decode_ql_nope, + decode_q_pe, + decode_k_nope, + decode_k_pe, kv_cache, + attn_metadata) + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + output[:num_decode_tokens] = output_decode + else: + output[:num_decode_tokens] = output_decode return output_padded diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 8a1b8d29fb..4d85d7f704 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -71,6 +71,13 @@ from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor +from vllm_ascend.multistream.context import (set_multistream_context,get_multistream_layer_context, + advance_step_multistream_layer_context, get_multistream_comm_context) +from vllm_ascend.multistream.layers import (MultiStreamPreTransformerLayer, MultiStreamPostTransformerLayer) +from vllm_ascend.multistream.metadata import make_multistream_metadata_ds, MultiStreamStepMetadata +from vllm_ascend.multistream.base import MSEventKey +from vllm_ascend.multistream.ms_split import compute_split_seq_index + VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 @@ -312,6 +319,46 @@ def forward( return hidden_states.view(num_tokens, hidden_size) + # ----------------------------------------- TBO-related -------------------------------------------- + def _forward_ms_op_shared_expert( + self, + hidden_states: torch.Tensor, + ): + shared_output = self.shared_experts(hidden_states) + return shared_output + + def _forward_ms_op_gate( + self, + hidden_states: torch.Tensor, + ): + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + return router_logits + + def _forward_ms_op_tp_allreduce( + self, + hidden_states: torch.Tensor, + shared_output: torch.Tensor, + chunk_hidden_states: torch.Tensor, + num_tokens: int = 0, + hidden_dim: int = 0, + ): + + if self.tp_size > 1: + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_tokens < self.tp_size: + final_hidden_states = final_hidden_states[:num_tokens] + else: + final_hidden_states = hidden_states + + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + final_hidden_states = final_hidden_states.view(num_tokens, hidden_dim) + + return final_hidden_states + class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): @@ -605,6 +652,206 @@ def forward( return hidden_states, residual + # ----------------------------------------- TBO-related -------------------------------------------- + def _forward_ms_layer( + self, + positions: List[torch.Tensor], + 
hidden_states: List[torch.Tensor], + residual: Optional[List[torch.Tensor]], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[List[AttentionMetadata]] = None, + is_prefill: bool = False, + ) -> List[torch.Tensor]: + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() + assert layer_index >= 0 and ms_metadata is not None + num_micro_batchs = ms_metadata.ms_config.num_micro_batches + assert isinstance(self.mlp, CustomDeepseekV2MoE) + assert len(positions) == num_micro_batchs + assert len(hidden_states) == num_micro_batchs + assert len(residual) == num_micro_batchs + assert len(attn_metadata) == num_micro_batchs + num_tokens = [] + hidden_dims = [] + shared_outputs = [] + router_logits = [] + chunk_hidden_states = [] + ''' block 1 : attention + block 2 : attn tp communication, currently we switch to the comm stream + in tensor_model_parallel_all_reduce; + the attn computation of microbatch 1 can be overlapped with the moe + communication in the previous layer, and the attn computation of microbatch + 2 can be overlapped with the attn communication of microbatch 1 + ''' + for i in range(num_micro_batchs): + # wait last layer moe finishing communication + ms_metadata.try_wait_event(layer_index-1, i, MSEventKey.FFN_AR_FINISH) + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.ATTN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.ATTN_AR_FINISH], + ) + + with set_multistream_context(context, i): + # input layernorm + hidden_states[i], residual[i] = self._forward_ms_op_input_layernorm(hidden_states[i], residual[i]) + # attention and tp allreducea + hidden_states[i], residual[i] = self._forward_ms_op_attn(positions[i], hidden_states[i], residual[i], kv_cache, attn_metadata[i]) + + ''' block 3 : shared experts + if there is an allreduce ops in shared expert, we can overlap it with the computation of the + shared expert for next microbatch or moe gating + ''' + for i in range(num_micro_batchs): + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMP_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH], + ) + with set_multistream_context(context, i): + # compute shared expert after finishing ATTN AR + ms_metadata.try_wait_event(layer_index, i, MSEventKey.ATTN_AR_FINISH) + hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm(hidden_states[i], residual[i]) + + + num_token, hidden_dim = hidden_states[i].shape + hidden_states[i] = hidden_states[i].view(-1, hidden_dim) + num_tokens.append(num_token) + hidden_dims.append(hidden_dim) + if self.mlp.n_shared_experts is not None: + # TODO: we can move shared expert computation into next block if reduce results is false + shared_output = self.mlp._forward_ms_op_shared_expert(hidden_states[i]) + shared_outputs.append(shared_output) + + # block 4 : moe + for i in range(num_micro_batchs): + #ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH) + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. 
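+            # Block-4 overview for each micro-batch i:
+            #   1) derive is_prefill / enable_force_load_balance from its own
+            #      (split) attention metadata,
+            #   2) pad hidden_states[i] when needed and tensor_split it across
+            #      TP ranks so the gate works on the TP-local chunk,
+            #   3) run the fused-MoE computation (_forward_ms_fused_moe_comp)
+            #      on the compute stream,
+            #   4) hand the all-reduce and the final all-gather off to the comm
+            #      stream so they overlap with the other micro-batch's compute.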
+ if attn_metadata[i] is None: + # for profile run + is_prefill = True + enable_force_load_balance = True + else: + is_prefill = attn_metadata[i].num_prefills > 0 + enable_force_load_balance = False + + if self.mlp.tp_size > 1: + if num_tokens[i] < self.mlp.tp_size: + target_size = self.mlp.tp_size + new_hidden_states = torch.empty([target_size, hidden_dims[i]], + dtype=hidden_states[i].dtype, + device=hidden_states[i].device) + new_hidden_states[:num_tokens[i]] = hidden_states[i] + hidden_states[i] = new_hidden_states + chunk_hidden_state = torch.tensor_split(hidden_states[i], + self.mlp.tp_size, + dim=0) + chunk_hidden_states.append(chunk_hidden_state) + local_hidden_states = chunk_hidden_state[self.mlp.tp_rank] + else: + local_hidden_states = hidden_states + + router_logit = self.mlp._forward_ms_op_gate(local_hidden_states) + router_logits.append(router_logit) + + + if CustomDeepseekV2MoE.top_k: + real_top_k = CustomDeepseekV2MoE.top_k + else: + real_top_k = self.mlp.experts.top_k + + if VLLM_ENABLE_MC2 and not is_prefill: + ... + + hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(hidden_states[i], router_logits[i], is_prefill, real_top_k, enable_force_load_balance) + + if VLLM_ENABLE_MC2 and not is_prefill: + ... + + ''' the following kernels will be submitted to the comm stream to overlap the computation of the + moe computation of next microbatch and the attn computation of next layer + ''' + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_AFTER_COMM], + ) + with set_multistream_context(context, i): + if self.mlp.experts.reduce_results and (self.mlp.experts.tp_size > 1 or self.mlp.experts.ep_size > 1): + hidden_states[i] = tensor_model_parallel_all_reduce( + hidden_states[i]) + # check here + hidden_states[i] = hidden_states[i] * self.mlp.routed_scaling_factor + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_AFTER_COMM], + after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_AR_FINISH], + ) + with set_multistream_context(context, i): + hidden_states[i] = self.mlp._forward_ms_op_tp_allreduce(hidden_states[i], shared_outputs[i], chunk_hidden_states[i], num_tokens[i], hidden_dims[i]) + with torch.npu.stream(ms_metadata.communicate_stream): + # last + if isinstance( + self.mlp, + CustomDeepseekV2MLP) and hidden_states[i].dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states[i] *= 1. 
/ self.routed_scaling_factor + return hidden_states, residual + # should split ops in Decoder Layer + def _forward_ms_op_input_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + return hidden_states, residual + def _forward_ms_op_attn( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ): + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + return hidden_states, residual + + def _forward_ms_op_post_attn_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ): + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + return hidden_states, residual + + + + class CustomDeepseekV2Model(nn.Module): @@ -620,6 +867,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + self.first_k_dense_replace = config.first_k_dense_replace if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( @@ -648,6 +896,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + + # tbo related members + self.multistream_config = vllm_config.model_config.multistream_config + self.use_mla = model_config.use_mla + self.multistream_metadata = make_multistream_metadata_ds( + start_layer=self.start_layer + self.first_k_dense_replace, + end_layer=self.end_layer, + causal_lm=getattr(config, "causal_lm", True), + multistream_config=self.multistream_config, + ) + self.ms_pre_layer = MultiStreamPreTransformerLayer(self.multistream_metadata) + self.ms_post_layer = MultiStreamPostTransformerLayer(self.multistream_metadata) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -672,13 +932,29 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): + num_normal_layers = ( + self.first_k_dense_replace + if self.multistream_config is not None and self.can_run_ms() + else self.end_layer - self.start_layer + ) + # if we enable multistream/dbo, only process dense layers here + for i in range(self.start_layer, self.start_layer + num_normal_layers): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, residual, kv_caches[i - self.start_layer] if kv_caches is not None else None, attn_metadata) + + moe_start_layer = self.start_layer + num_normal_layers + hidden_states, residual = self._forward_ms_layers( + positions=positions, + hidden_states=hidden_states, + residual=residual, + 
moe_start_layer=moe_start_layer, + attn_metadata=attn_metadata, + kv_caches=kv_caches, + ) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -689,6 +965,68 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def can_run_ms(self): + # currently we only enable prefill overlap + attn_metadata = get_forward_context().attn_metadata + dp_metadata = get_forward_context().dp_metadata + # profile run + if self.multistream_config is None or attn_metadata is None: + return False + # support mla attention and V1 engine at present + if not self.use_mla or not envs.VLLM_USE_V1: + return False + # disable decode dbo + if attn_metadata.num_prefills == 0: + return False + num_microbatchs = self.multistream_config.num_micro_batches + # check whether there is a dp rank that not use dual batch + if dp_metadata is not None: + for i in range(num_microbatchs): + cu_tokens = dp_metadata.cu_dbo_tokens_across_dp_cpu[i] + if torch.any(cu_tokens == 0).item(): + return False + [token_index, seq_index] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, attn_metadata.num_decode_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len(attn_metadata.query_lens): + return False + # check whether the total tokens exceed the threshold + if attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: + return False + return True + def _forward_ms_layers( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + moe_start_layer: int, + attn_metadata:Optional[AttentionMetadata] = None, + kv_caches: Optional[List[torch.Tensor]] = None, + is_prefill: bool = False, + ): + + if moe_start_layer == self.end_layer: + return hidden_states, residual + + attn_metadata, [positions, hidden_states, residual] = self.ms_pre_layer( + [positions, hidden_states, residual], + ) + # the rest layers + for i in range(moe_start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer._forward_ms_layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + kv_cache=kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata=attn_metadata, + is_prefill=is_prefill) + advance_step_multistream_layer_context() + + [hidden_states, residual] = self.ms_post_layer( + [hidden_states, residual], + ) + return hidden_states, residual class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging diff --git a/vllm_ascend/multistream/__init__.py b/vllm_ascend/multistream/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_ascend/multistream/base.py b/vllm_ascend/multistream/base.py new file mode 100644 index 0000000000..5eb89e6dd1 --- /dev/null +++ b/vllm_ascend/multistream/base.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from enum import Enum + +# TODO: move this part to vllm +class MSEventKey(Enum): + ATTN_COM_FINISH = 0 + ATTN_AR_FINISH = 1 + FFN_COM_FINISH = 2 + FFN_AR_FINISH = 3 + # events for MOE dispatch and combine + MOE_BEFORE_COMM = 4 + MOE_AFTER_COMM = 5 + # events for shared expert + MOE_SE_COMM_FINISH = 6 + MOE_SE_COMP_FINISH = 7 + MOE_GATE_FINISH = 8 +@dataclass +class MSAttentionMetadataSplitConfig: + """ + micro batch split config for split attention metadata + """ + # micro batch num + num_micro_batches: int = 2 + # split micro batches only when total tokens >= min_total_tokens_to_split + 
min_total_tokens_to_split: int = 256, + # split micro batches only when prefill tokens >= min_prefill_tokens_to_split + min_prefill_tokens_to_split: int = 64, \ No newline at end of file diff --git a/vllm_ascend/multistream/context.py b/vllm_ascend/multistream/context.py new file mode 100644 index 0000000000..21c9001464 --- /dev/null +++ b/vllm_ascend/multistream/context.py @@ -0,0 +1,62 @@ +from contextlib import contextmanager +from typing import Any + +# TODO: move this part to vllm + +_ms_comm_context: Any = None +_cur_micro_batch_num: int = -1 +_ms_layer_index_context: int = -1 +_ms_metadata_context: Any = None +_ms_attn_metadata_context: Any = None + +def set_multistream_layer_context(start_layer: int, ms_metadata: Any, attn_metadata: Any): + """ + set multistream layer context before transformer layers + """ + global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context + _ms_layer_index_context = start_layer + _ms_metadata_context = ms_metadata + _ms_attn_metadata_context = attn_metadata + +def reset_multistream_layer_context(): + """ + reset multistream layer context + """ + global _ms_layer_index_context, _ms_metadata_context + _ms_layer_index_context = -1 + _ms_metadata_context = None + _ms_attn_metadata_context = None + +def get_multistream_layer_context(): + """ + get multistream layer context + """ + return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context + +def advance_step_multistream_layer_context(): + """ + advance multistream layer index context + """ + global _ms_layer_index_context + _ms_layer_index_context += 1 + + +def get_multistream_comm_context() -> Any: + """Get the current comm forward context.""" + return _ms_comm_context + +def get_multistream_microbatch_context() -> int: + return _cur_micro_batch_num + +@contextmanager +def set_multistream_context(context: Any, micro_batch_num: int): + """A context manager that stores the current comm forward context, + can be attention metadata, etc.""" + global _ms_comm_context, _cur_micro_batch_num + _ms_comm_context = context + _cur_micro_batch_num = micro_batch_num + try: + yield + finally: + _ms_comm_context = None + _cur_micro_batch_num = -1 diff --git a/vllm_ascend/multistream/decorator.py b/vllm_ascend/multistream/decorator.py new file mode 100644 index 0000000000..705a0bc697 --- /dev/null +++ b/vllm_ascend/multistream/decorator.py @@ -0,0 +1,21 @@ +from .context import (get_multistream_layer_context, + get_multistream_microbatch_context) +from vllm.logger import init_logger + +# TODO: move this part to vllm + +logger = init_logger(__name__) + +# vllm v1 use get_forward_context to get the attn_metadata, +# we update it to the splitted version if enable dbo +def set_multistream_support(): + def decorator(func): + def wrapper(): + context = func() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() + micro_batch_num = get_multistream_microbatch_context() + if layer_index != -1 and micro_batch_num != -1: + context.attn_metadata = attn_metadata[micro_batch_num] + return context + return wrapper + return decorator \ No newline at end of file diff --git a/vllm_ascend/multistream/layers.py b/vllm_ascend/multistream/layers.py new file mode 100644 index 0000000000..25468f4312 --- /dev/null +++ b/vllm_ascend/multistream/layers.py @@ -0,0 +1,42 @@ +import torch +from typing import List, Union, Tuple +from vllm.forward_context import get_forward_context +from .base import MSEventKey +from .metadata import MultiStreamMetadata +from vllm_ascend.multistream.context 
import (set_multistream_layer_context, reset_multistream_layer_context, + get_multistream_layer_context) + +# TODO: move this part to vllm +class MultiStreamPreTransformerLayer(torch.nn.Module): + def __init__(self, multistream_metadata: MultiStreamMetadata): + super().__init__() + self.multistream_metadata = multistream_metadata + def forward(self, + intput_tensors: List[torch.Tensor],): + attn_metadata = get_forward_context().attn_metadata + if self.multistream_metadata is None or attn_metadata is None: + set_multistream_layer_context(-1, None, None) + return attn_metadata, intput_tensors + # TODO add attn_metadata management + do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch(attn_metadata, intput_tensors) + if do_ms: + set_multistream_layer_context(self.multistream_metadata.start_layer, self.multistream_metadata, attn_metadata) + else: + set_multistream_layer_context(-1, None, None) + return attn_metadata, intput_tensors +class MultiStreamPostTransformerLayer(torch.nn.Module): + def __init__(self, multistream_metadata: MultiStreamMetadata): + super().__init__() + self.multistream_metadata = multistream_metadata + def forward(self, input_tensors: Union[List[Tuple[torch.Tensor]], List[torch.Tensor], List[List[torch.Tensor]]], + wait_layer_index: int = None): + if self.multistream_metadata is None: + return input_tensors + layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context() + if layer_index >= 0: + true_wait_layer = self.multistream_metadata.end_layer-1 if wait_layer_index is None else wait_layer_index + self.multistream_metadata.try_wait_event(true_wait_layer, + self.multistream_metadata.ms_config.num_micro_batches-1, + MSEventKey.FFN_AR_FINISH) + reset_multistream_layer_context() + return self.multistream_metadata.merge_micro_batches(input_tensors) \ No newline at end of file diff --git a/vllm_ascend/multistream/metadata.py b/vllm_ascend/multistream/metadata.py new file mode 100644 index 0000000000..dc6f718bfa --- /dev/null +++ b/vllm_ascend/multistream/metadata.py @@ -0,0 +1,160 @@ +from dataclasses import dataclass +import torch +from typing import Dict, List, Optional, Union, Tuple +from vllm.sequence import IntermediateTensors +from vllm.config import MultiStreamConfig +from .base import MSAttentionMetadataSplitConfig, MSEventKey +from vllm.attention.backends.abstract import AttentionMetadata + +# TODO: move this part to vllm +def split_micro_batches_tensors(input_tensors, split_index: int, keys: List[str] = None): + if isinstance(input_tensors, list): + micro_batches = [] + for tensor in input_tensors: + if tensor is None: + micro_batches.append([None, None]) + else: + micro_batches.append([tensor[:split_index], tensor[split_index:]]) + return micro_batches + elif isinstance(input_tensors, torch.Tensor): + return [input_tensors[:split_index], input_tensors[split_index:]] + elif input_tensors is None: + return [None, None] + elif isinstance(input_tensors, Dict): + assert keys is not None + micro_batches_pre = {} + for key in keys: + micro_batches_pre[key] = input_tensors[key][:split_index] + micro_batches_post = {} + for key in keys: + micro_batches_post[key] = input_tensors[key][split_index:] + return [micro_batches_pre, micro_batches_post] + else: + raise NotImplementedError +def make_multistream_metadata( + start_layer: int, + end_layer: int, + causal_lm: bool = True, + multistream_config: Optional[MultiStreamConfig] = None, +): + if multistream_config is None: + return None + return MultiStreamMetadata( + 
calculate_stream=torch.npu.current_stream(), + communicate_stream=torch.npu.Stream(), + start_layer=start_layer, + end_layer=end_layer, + multistream_config=multistream_config, + event_keys=[MSEventKey.ATTN_COM_FINISH, MSEventKey.ATTN_AR_FINISH, + MSEventKey.FFN_COM_FINISH, MSEventKey.FFN_AR_FINISH], + causal_lm=causal_lm, + ) +def make_multistream_metadata_ds( + start_layer: int, + end_layer: int, + causal_lm: bool = True, + multistream_config: Optional[MultiStreamConfig] = None, +): + if multistream_config is None: + return None + event_keylist = [ + MSEventKey.ATTN_COM_FINISH, + MSEventKey.ATTN_AR_FINISH, + MSEventKey.FFN_COM_FINISH, + MSEventKey.FFN_AR_FINISH, + MSEventKey.MOE_BEFORE_COMM, + MSEventKey.MOE_AFTER_COMM, + MSEventKey.MOE_SE_COMM_FINISH, + MSEventKey.MOE_SE_COMP_FINISH, + MSEventKey.MOE_GATE_FINISH, + ] + return MultiStreamMetadata( + calculate_stream=torch.npu.current_stream(), + communicate_stream=torch.npu.Stream(), + start_layer=start_layer, + end_layer=end_layer, + multistream_config=multistream_config, + event_keys=event_keylist, + causal_lm=causal_lm, + ) +@dataclass +class MultiStreamStepMetadata: + comm_stream: torch.npu.Stream = None + before_comm_event: torch.npu.Event = None + after_comm_event: torch.npu.Event = None +class MultiStreamMetadata: + # direct stream + calculate_stream = None + # delay stream + communicate_stream = None + # events + ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {} + # multi-stream-flag + enable_multi_stream: bool = False + + def __init__(self, + calculate_stream: torch.npu.Stream, + communicate_stream: torch.npu.Stream, + start_layer: int, + end_layer: int, + event_keys: List[MSEventKey], + multistream_config: Optional[MultiStreamConfig], + causal_lm: bool = True, + ): + self.calculate_stream = calculate_stream + self.communicate_stream = communicate_stream + self.start_layer = start_layer + self.end_layer = end_layer + self.ms_config = multistream_config + self.causal_lm = causal_lm + self._build_events(event_keys) + self._build_ms_split_config() + def _build_events(self, event_keys): + for i in range(self.start_layer - 1, self.end_layer): + self.ms_events[i] = {} + for j in range(self.ms_config.num_micro_batches): + self.ms_events[i][j] = {} + for key in event_keys: + self.ms_events[i][j][key] = torch.npu.Event() + def _build_ms_split_config(self): + self.ms_split_config = MSAttentionMetadataSplitConfig( + num_micro_batches=self.ms_config.num_micro_batches, + min_total_tokens_to_split=self.ms_config.min_total_tokens_to_split, + min_prefill_tokens_to_split=self.ms_config.min_prefill_tokens_to_split, + ) + def try_wait_event(self, layer_index: int, micro_batch_index: int, event_key: MSEventKey): + self.ms_events[layer_index][micro_batch_index][event_key].wait() + def try_record_event(self, layer_index: int, micro_batch_index: int, event_key: MSEventKey): + self.ms_events[layer_index][micro_batch_index][event_key].record() + def split_micro_batch(self, + attn_metadata: "AttentionMetadata", + intput_tensors: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors_keys: Optional[List[str]] = None, + ) -> Tuple[bool, + Union[AttentionMetadata, List[AttentionMetadata]], + Union[List[torch.Tensor], List[List[torch.Tensor]]], + Union[IntermediateTensors, List[IntermediateTensors]]]: + attn_metadata = attn_metadata.split_metadata_for_multistream(self.ms_split_config) + if len(attn_metadata) == 1: + return False, attn_metadata[0], intput_tensors, intermediate_tensors + 
split_index = attn_metadata[0].slot_mapping.shape[0] + input_tensors = split_micro_batches_tensors(intput_tensors, split_index) + if intermediate_tensors is not None: + inter_tensors_list = split_micro_batches_tensors(intermediate_tensors.tensors, split_index, intermediate_tensors_keys) + intermediate_tensors = [ + IntermediateTensors(inter_tensors) for inter_tensors in inter_tensors_list + ] + return True, attn_metadata, input_tensors, intermediate_tensors + def merge_micro_batches(self, + input_tensors: Union[List[torch.Tensor], List[List[torch.Tensor]]] + ) -> List[torch.Tensor]: + if input_tensors is None or isinstance(input_tensors[0], torch.Tensor): + return input_tensors + batch = [] + for tensors in input_tensors: + if tensors is None or tensors[0] is None: + batch.append(None) + else: + batch.append(torch.cat(tensors, dim=0)) + return batch \ No newline at end of file diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py new file mode 100644 index 0000000000..583e8a7383 --- /dev/null +++ b/vllm_ascend/multistream/ms_split.py @@ -0,0 +1,173 @@ +import torch +from typing import Optional, Any, List +import numpy as np +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +def compute_split_seq_index( + query_lens: Optional[list[int]], + attn_state: AscendAttentionState, + num_tokens: int, + imbalance_ratio: float = 0.1, + )->Optional[list[int]]: + if attn_state == AscendAttentionState.PrefillOnly or attn_state == AscendAttentionState.ChunkedPrefill: + assert query_lens is not None + total_tokens = sum(query_lens) + # the first index in last split + tokens, split_index = 0, 0 + for value in query_lens: + tokens += value + split_index += 1 + if tokens >= total_tokens // 2 : + # check the current split index + if abs(tokens - total_tokens // 2) < total_tokens * imbalance_ratio: + return [tokens, split_index] + # check the previous split index + elif abs(tokens - total_tokens // 2 - value) < total_tokens * imbalance_ratio: + return [tokens-value, split_index-1] + # fail to split if it is imbalanced + # TODO: split tokens in seq + else : + return [0, 0] + elif attn_state == AscendAttentionState.DecodeOnly: + tokens = num_tokens // 2 + return [tokens, tokens] + else: + return [0, 0] +def split_attn_tensor_type( + input_tensor: torch.Tensor, + index: int, + )-> List[torch.Tensor]: + return [input_tensor[:index], input_tensor[index:]] +def split_attn_int_type( + var: int, + index: int, + )-> List[torch.Tensor]: + return [min(var,index), max(var-index, 0)] + +def model_input_split_v1_mla_attn( + attn_metadata, + _metadata_cls, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> List[Any]: + assert 0 < ms_split_config.num_micro_batches < 3 + [token_index, seq_index] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, attn_metadata.num_decode_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len(attn_metadata.query_lens): + return [attn_metadata] + + query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1,), dtype=int) + np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:]) + if attn_metadata.num_prefills > 0: + prefill_query_start_loc = np.zeros(shape=(len(attn_metadata.prefill.query_lens) + 1,), dtype=int) + np.cumsum(attn_metadata.prefill.query_lens, out=prefill_query_start_loc[1:]) + + # split attn metadata + [slot_mapping_pre, slot_mapping_post] = 
split_attn_tensor_type(attn_metadata.slot_mapping,token_index) + [num_decodes_pre, num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes, seq_index) + [num_decode_tokens_pre, num_decode_tokens_post] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index) + [num_prefills_pre, num_prefills_post] = split_attn_int_type(attn_metadata.num_prefills, max(0,seq_index-attn_metadata.num_decodes)) + seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills>0 else attn_metadata.decode.seq_lens + [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens,seq_index) + + if attn_metadata.attn_state == AscendAttentionState.PrefillOnly: + # the attn_mla kernel in torch npu only accept 128*128 attn mask + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = AscendAttentionState.PrefillOnly + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + # should be none in decode only state + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly + else : + # chunked prefill + if num_prefills_pre > 0: + attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(seq_lens_pre)].contiguous() + attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_post = attn_metadata.attn_mask[token_index:, :max(seq_lens_post)].contiguous() + else: + attn_state_pre = AscendAttentionState.DecodeOnly + attn_mask_pre = None + attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_post = attn_metadata.attn_mask[token_index:, :max(seq_lens_post)].contiguous() + from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata, AscendMLADecodeMetadata + if num_prefills_pre > 0: + # split metadata.prefill + [input_positions_pre, input_positions_post] = split_attn_tensor_type(attn_metadata.prefill.input_positions, token_index - attn_metadata.num_decode_tokens) + [block_tables_pre, block_tables_post] = split_attn_tensor_type(attn_metadata.prefill.block_table, seq_index- attn_metadata.num_decodes) + [prefill_query_lens_pre, prefill_query_lens_post] = split_attn_tensor_type(attn_metadata.prefill.query_lens,seq_index- attn_metadata.num_decodes) + context_len_pre = seq_lens_pre[attn_metadata.num_decodes:] + context_len_post = seq_lens_post + prefill_max_query_len_pre = max(prefill_query_lens_pre) + prefill_max_query_len_post = max(prefill_query_lens_post) + prefill_pre = AscendMLAPrefillMetadata( + attn_mask=attn_mask_pre, + query_lens=prefill_query_lens_pre, + seq_lens=seq_lens_pre, + input_positions=input_positions_pre, + context_lens=context_len_pre, + block_table=block_tables_pre, + max_query_len=prefill_max_query_len_pre, + max_seq_lens=context_len_pre.max().item(), + ) + prefill_post = AscendMLAPrefillMetadata( + attn_mask=attn_mask_post, + query_lens=prefill_query_lens_post, + seq_lens=seq_lens_post, + input_positions=input_positions_post, + context_lens=context_len_post, + block_table=block_tables_post, + max_query_len=prefill_max_query_len_post, + max_seq_lens=context_len_post.max().item(), + ) + decode_pre = attn_metadata.decode + decode_post = None + else : + # prefill is None, split metadata.decode + [input_positions_pre, input_positions_post] = split_attn_tensor_type(attn_metadata.decode.input_positions, token_index ) + [block_tables_pre, block_tables_post] = split_attn_tensor_type(attn_metadata.decode.block_table, seq_index) + [decode_seq_lens_pre, decode_seq_lens_post] = 
split_attn_tensor_type(seq_lens,seq_index) + decode_pre = AscendMLADecodeMetadata( + input_positions=input_positions_pre, + block_table=block_tables_pre, + seq_lens=decode_seq_lens_pre, + max_seq_lens=max(decode_seq_lens_pre), + seq_lens_list=decode_seq_lens_pre.tolist(), + ) + decode_post = AscendMLADecodeMetadata( + input_positions=input_positions_post, + block_table=block_tables_post, + seq_lens=decode_seq_lens_post, + max_seq_lens=max(decode_seq_lens_post), + seq_lens_list=decode_seq_lens_post.tolist(), + ) + prefill_pre = None + prefill_post = attn_metadata.prefill + # construct metadata + from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata + attention_metadata_pre = _metadata_cls( + num_actual_tokens=token_index, + num_input_tokens=token_index, + head_dim=attn_metadata.head_dim, + slot_mapping=slot_mapping_pre, + num_decodes=num_decodes_pre, + num_prefills=num_prefills_pre, + num_decode_tokens=num_decode_tokens_pre, + attn_state = attn_state_pre, + attn_mask = attn_mask_pre, + prefill=prefill_pre, + decode=decode_pre, + ) + attention_metadata_post = _metadata_cls( + num_actual_tokens=attn_metadata.num_actual_tokens - token_index, + num_input_tokens = attn_metadata.num_input_tokens - token_index, + head_dim = attn_metadata.head_dim, + slot_mapping=slot_mapping_post, + num_decodes=num_decodes_post, + num_prefills=num_prefills_post, + num_decode_tokens=num_decode_tokens_post, + attn_mask = attn_mask_post, + attn_state = attn_state_post, + prefill=prefill_post, + decode=decode_post, + ) + return [attention_metadata_pre, attention_metadata_post] diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 6aff62fc62..f3d3920fd9 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -31,6 +31,9 @@ from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig +from vllm.multistream.context import set_multistream_context, get_multistream_comm_context +from vllm_ascend.multistream.base import MSEventKey +from vllm_ascend.multistream.metadata import MultiStreamStepMetadata, MultiStreamMetadata import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group @@ -896,3 +899,32 @@ def forward(self, if self.enable_multistream_shared_expert and not is_prefill: return hidden_states, shared_output return hidden_states + + # ----------------------------------------- TBO-related -------------------------------------------- + + def _forward_ms_fused_moe_comp( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + real_top_k, + enable_force_load_balance: bool = False, + ): + hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=real_top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + is_prefill=is_prefill, + enable_force_load_balance=enable_force_load_balance) + + return hidden_states diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 07ea679312..76406d6cdc 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -706,6 
+706,7 @@ def _process_reqs( # Run forward pass with set_forward_context(attn_metadata, self.vllm_config, + query_lens=self.query_lens, num_tokens=num_input_tokens): with ProfileExecuteDuration().capture_async("forward"): model_kwargs = {} From b1b8d6dd33cc5a333295a72cf5eb218557ecef0e Mon Sep 17 00:00:00 2001 From: zhuohuan Date: Mon, 26 May 2025 23:16:37 +0800 Subject: [PATCH 02/11] [fix]: reduced dependency on vllm for dbo Signed-off-by: zhuohuan --- vllm_ascend/attention/mla_v1.py | 47 ++++++++------ vllm_ascend/models/deepseek_v2.py | 48 ++++++++++----- vllm_ascend/multistream/metadata.py | 89 ++++++++++++--------------- vllm_ascend/multistream/ms_split.py | 11 ++-- vllm_ascend/ops/fused_moe.py | 1 - vllm_ascend/worker/model_runner_v1.py | 1 - 6 files changed, 106 insertions(+), 91 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index cd4b746992..2ccec20a41 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -830,34 +830,41 @@ def forward( # FIX: aicore move/copy should be also placed on the comm stream in dbo, # otherwise it may affect the accuracy or disturb the overlap of next stage # TODO: use an elegant way here to avoid it - output_prefill = self._forward_prefill(prefill_q, - prefill_k_c_normed, - prefill_k_pe, kv_cache, - attn_metadata) - from vllm.multistream.context import get_multistream_comm_context + from vllm_ascend.multistream.context import get_multistream_comm_context current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is not None: + if current_ms_metadata is None: + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata) + else: + current_ms_metadata.before_comm_event.record() with torch.npu.stream(current_ms_metadata.comm_stream): - output[num_decode_tokens:] = output_prefill + current_ms_metadata.before_comm_event.wait() + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata) current_ms_metadata.after_comm_event.record() - else: - output[num_decode_tokens:] = output_prefill + if has_decode: if self.running_in_graph: return self._forward_decode(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, kv_cache, attn_metadata) else: - from vllm.multistream.context import get_multistream_comm_context + + from vllm_ascend.multistream.context import get_multistream_comm_context current_ms_metadata = get_multistream_comm_context() - output_decode = self._forward_decode(decode_ql_nope, - decode_q_pe, - decode_k_nope, - decode_k_pe, kv_cache, - attn_metadata) - if current_ms_metadata is not None: - with torch.npu.stream(current_ms_metadata.comm_stream): - output[:num_decode_tokens] = output_decode - else: - output[:num_decode_tokens] = output_decode + if current_ms_metadata is None: + output[:num_decode_tokens] = self._forward_decode( + decode_ql_nope, decode_q_pe, decode_k_nope, + decode_k_pe, kv_cache, attn_metadata) + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + output[:num_decode_tokens] = self._forward_decode( + decode_ql_nope, decode_q_pe, decode_k_nope, + decode_k_pe, kv_cache, attn_metadata) + current_ms_metadata.after_comm_event.record() + return output_padded diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 4d85d7f704..26247c0e61 100644 --- 
a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -74,11 +74,12 @@ from vllm_ascend.multistream.context import (set_multistream_context,get_multistream_layer_context, advance_step_multistream_layer_context, get_multistream_comm_context) from vllm_ascend.multistream.layers import (MultiStreamPreTransformerLayer, MultiStreamPostTransformerLayer) -from vllm_ascend.multistream.metadata import make_multistream_metadata_ds, MultiStreamStepMetadata +from vllm_ascend.multistream.metadata import make_multistream_metadata_ds, MultiStreamStepMetadata, MultiStreamConfig from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.ms_split import compute_split_seq_index VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 +VLLM_ENABLE_MS: bool = envs_ascend.VLLM_ENABLE_MS class CustomDeepseekV2MLP(nn.Module): @@ -348,8 +349,10 @@ def _forward_ms_op_tp_allreduce( dist.all_gather(list(chunk_hidden_states), hidden_states, self.tp_group) final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - if num_tokens < self.tp_size: - final_hidden_states = final_hidden_states[:num_tokens] + #if num_tokens < self.tp_size: + # final_hidden_states = final_hidden_states[:num_tokens] + if num_tokens > 0: + final_hidden_states = final_hidden_states[:-num_tokens] else: final_hidden_states = hidden_states @@ -692,6 +695,10 @@ def _forward_ms_layer( ) with set_multistream_context(context, i): + context = get_forward_context() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() + context.attn_metadata = attn_metadata[i] + # input layernorm hidden_states[i], residual[i] = self._forward_ms_op_input_layernorm(hidden_states[i], residual[i]) # attention and tp allreducea @@ -715,7 +722,7 @@ def _forward_ms_layer( num_token, hidden_dim = hidden_states[i].shape hidden_states[i] = hidden_states[i].view(-1, hidden_dim) - num_tokens.append(num_token) + #num_tokens.append(num_token) hidden_dims.append(hidden_dim) if self.mlp.n_shared_experts is not None: # TODO: we can move shared expert computation into next block if reduce results is false @@ -737,13 +744,20 @@ def _forward_ms_layer( enable_force_load_balance = False if self.mlp.tp_size > 1: - if num_tokens[i] < self.mlp.tp_size: - target_size = self.mlp.tp_size - new_hidden_states = torch.empty([target_size, hidden_dims[i]], - dtype=hidden_states[i].dtype, - device=hidden_states[i].device) - new_hidden_states[:num_tokens[i]] = hidden_states[i] - hidden_states[i] = new_hidden_states + #if num_tokens[i] < self.mlp.tp_size: + # target_size = self.mlp.tp_size + # new_hidden_states = torch.empty([target_size, hidden_dims[i]], + # dtype=hidden_states[i].dtype, + # device=hidden_states[i].device) + # new_hidden_states[:num_tokens[i]] = hidden_states[i] + # hidden_states[i] = new_hidden_states + num_token, _ = hidden_states[i].shape + padded_num_tokens = (self.mlp.tp_size - + num_token % self.mlp.tp_size) % self.mlp.tp_size + if padded_num_tokens > 0: + hidden_states[i] = nn.functional.pad(hidden_states[i], + (0, 0, 0, padded_num_tokens)) + num_tokens.append(padded_num_tokens) chunk_hidden_state = torch.tensor_split(hidden_states[i], self.mlp.tp_size, dim=0) @@ -764,7 +778,7 @@ def _forward_ms_layer( if VLLM_ENABLE_MC2 and not is_prefill: ... 
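+            # Feed the TP-local chunk to the fused-MoE kernel; the padding and
+            # tensor_split above guarantee each rank gets an equally sized
+            # slice, and the pad rows are sliced off again after the all-gather.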
- hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(hidden_states[i], router_logits[i], is_prefill, real_top_k, enable_force_load_balance) + hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(local_hidden_states, router_logits[i], is_prefill, real_top_k, enable_force_load_balance) if VLLM_ENABLE_MC2 and not is_prefill: ... @@ -898,7 +912,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ["hidden_states", "residual"], config.hidden_size)) # tbo related members - self.multistream_config = vllm_config.model_config.multistream_config + if VLLM_ENABLE_MS: + self.multistream_config = MultiStreamConfig() + else: + self.multistream_config = None self.use_mla = model_config.use_mla self.multistream_metadata = make_multistream_metadata_ds( start_layer=self.start_layer + self.first_k_dense_replace, @@ -980,13 +997,14 @@ def can_run_ms(self): return False num_microbatchs = self.multistream_config.num_micro_batches # check whether there is a dp rank that not use dual batch - if dp_metadata is not None: + '''if dp_metadata is not None: for i in range(num_microbatchs): cu_tokens = dp_metadata.cu_dbo_tokens_across_dp_cpu[i] if torch.any(cu_tokens == 0).item(): return False [token_index, seq_index] = compute_split_seq_index(attn_metadata.query_lens, - attn_metadata.attn_state, attn_metadata.num_decode_tokens) + attn_metadata.attn_state, attn_metadata.num_decode_tokens) + ''' if token_index == 0 or seq_index == 0 or seq_index == len(attn_metadata.query_lens): return False # check whether the total tokens exceed the threshold diff --git a/vllm_ascend/multistream/metadata.py b/vllm_ascend/multistream/metadata.py index dc6f718bfa..79ed4e33fa 100644 --- a/vllm_ascend/multistream/metadata.py +++ b/vllm_ascend/multistream/metadata.py @@ -2,7 +2,6 @@ import torch from typing import Dict, List, Optional, Union, Tuple from vllm.sequence import IntermediateTensors -from vllm.config import MultiStreamConfig from .base import MSAttentionMetadataSplitConfig, MSEventKey from vllm.attention.backends.abstract import AttentionMetadata @@ -31,57 +30,21 @@ def split_micro_batches_tensors(input_tensors, split_index: int, keys: List[str] return [micro_batches_pre, micro_batches_post] else: raise NotImplementedError -def make_multistream_metadata( - start_layer: int, - end_layer: int, - causal_lm: bool = True, - multistream_config: Optional[MultiStreamConfig] = None, -): - if multistream_config is None: - return None - return MultiStreamMetadata( - calculate_stream=torch.npu.current_stream(), - communicate_stream=torch.npu.Stream(), - start_layer=start_layer, - end_layer=end_layer, - multistream_config=multistream_config, - event_keys=[MSEventKey.ATTN_COM_FINISH, MSEventKey.ATTN_AR_FINISH, - MSEventKey.FFN_COM_FINISH, MSEventKey.FFN_AR_FINISH], - causal_lm=causal_lm, - ) -def make_multistream_metadata_ds( - start_layer: int, - end_layer: int, - causal_lm: bool = True, - multistream_config: Optional[MultiStreamConfig] = None, -): - if multistream_config is None: - return None - event_keylist = [ - MSEventKey.ATTN_COM_FINISH, - MSEventKey.ATTN_AR_FINISH, - MSEventKey.FFN_COM_FINISH, - MSEventKey.FFN_AR_FINISH, - MSEventKey.MOE_BEFORE_COMM, - MSEventKey.MOE_AFTER_COMM, - MSEventKey.MOE_SE_COMM_FINISH, - MSEventKey.MOE_SE_COMP_FINISH, - MSEventKey.MOE_GATE_FINISH, - ] - return MultiStreamMetadata( - calculate_stream=torch.npu.current_stream(), - communicate_stream=torch.npu.Stream(), - start_layer=start_layer, - end_layer=end_layer, - multistream_config=multistream_config, - 
event_keys=event_keylist, - causal_lm=causal_lm, - ) + @dataclass class MultiStreamStepMetadata: comm_stream: torch.npu.Stream = None before_comm_event: torch.npu.Event = None after_comm_event: torch.npu.Event = None + +@dataclass +class MultiStreamConfig: + """Controls the behavior of multi-stream models.""" + min_total_tokens_to_split: int = 256 + min_prefill_tokens_to_split: int = 64 + num_micro_batches: int = 2 + imbalance_ratio: float = 0.1 + class MultiStreamMetadata: # direct stream calculate_stream = None @@ -157,4 +120,34 @@ def merge_micro_batches(self, batch.append(None) else: batch.append(torch.cat(tensors, dim=0)) - return batch \ No newline at end of file + return batch + + +def make_multistream_metadata_ds( + start_layer: int, + end_layer: int, + causal_lm: bool = True, + multistream_config: Optional[MultiStreamConfig] = None, +): + if multistream_config is None: + return None + event_keylist = [ + MSEventKey.ATTN_COM_FINISH, + MSEventKey.ATTN_AR_FINISH, + MSEventKey.FFN_COM_FINISH, + MSEventKey.FFN_AR_FINISH, + MSEventKey.MOE_BEFORE_COMM, + MSEventKey.MOE_AFTER_COMM, + MSEventKey.MOE_SE_COMM_FINISH, + MSEventKey.MOE_SE_COMP_FINISH, + MSEventKey.MOE_GATE_FINISH, + ] + return MultiStreamMetadata( + calculate_stream=torch.npu.current_stream(), + communicate_stream=torch.npu.Stream(), + start_layer=start_layer, + end_layer=end_layer, + multistream_config=multistream_config, + event_keys=event_keylist, + causal_lm=causal_lm, + ) \ No newline at end of file diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 583e8a7383..636bc86e1d 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -9,7 +9,7 @@ def compute_split_seq_index( num_tokens: int, imbalance_ratio: float = 0.1, )->Optional[list[int]]: - if attn_state == AscendAttentionState.PrefillOnly or attn_state == AscendAttentionState.ChunkedPrefill: + if attn_state != AscendAttentionState.DecodeOnly: assert query_lens is not None total_tokens = sum(query_lens) # the first index in last split @@ -28,11 +28,10 @@ def compute_split_seq_index( # TODO: split tokens in seq else : return [0, 0] - elif attn_state == AscendAttentionState.DecodeOnly: + else: tokens = num_tokens // 2 return [tokens, tokens] - else: - return [0, 0] + def split_attn_tensor_type( input_tensor: torch.Tensor, index: int, @@ -69,10 +68,10 @@ def model_input_split_v1_mla_attn( seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills>0 else attn_metadata.decode.seq_lens [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens,seq_index) - if attn_metadata.attn_state == AscendAttentionState.PrefillOnly: + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: # the attn_mla kernel in torch npu only accept 128*128 attn mask attn_mask_pre = attn_mask_post = attn_metadata.attn_mask - attn_state_pre = attn_state_post = AscendAttentionState.PrefillOnly + attn_state_pre = attn_state_post = attn_metadata.attn_state elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # should be none in decode only state attn_mask_pre = attn_mask_post = attn_metadata.attn_mask diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index f3d3920fd9..c9cd9e7f82 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig -from vllm.multistream.context import 
set_multistream_context, get_multistream_comm_context from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.metadata import MultiStreamStepMetadata, MultiStreamMetadata import vllm_ascend.envs as envs_ascend diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 76406d6cdc..07ea679312 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -706,7 +706,6 @@ def _process_reqs( # Run forward pass with set_forward_context(attn_metadata, self.vllm_config, - query_lens=self.query_lens, num_tokens=num_input_tokens): with ProfileExecuteDuration().capture_async("forward"): model_kwargs = {} From 92f813c496f7988eed45f59a52762dc66b39ac63 Mon Sep 17 00:00:00 2001 From: zhuohuan Date: Tue, 27 May 2025 21:42:36 +0800 Subject: [PATCH 03/11] [feat]: improve overlap performance Signed-off-by: zhuohuan --- vllm_ascend/attention/mla_v1.py | 84 +++--- vllm_ascend/models/deepseek_v2.py | 388 +++++++++++++++++---------- vllm_ascend/multistream/base.py | 5 +- vllm_ascend/multistream/decorator.py | 2 +- vllm_ascend/multistream/layers.py | 15 +- vllm_ascend/multistream/metadata.py | 18 +- vllm_ascend/multistream/ms_split.py | 12 +- vllm_ascend/ops/fused_moe.py | 2 - 8 files changed, 334 insertions(+), 192 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 2ccec20a41..e78ef4ca5a 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -4,6 +4,12 @@ import numpy as np import torch import torch_npu +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn +from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner + from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) @@ -601,7 +607,18 @@ def _forward_prefill( ) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) - return self.o_proj(attn_output)[0] + + # A better way is to modify the communication ops or RowParallel Layer in vllm; + from vllm_ascend.multistream.context import \ + get_multistream_comm_context + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + return self.o_proj(attn_output)[0] + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + return self.o_proj(attn_output)[0] def exec_kv( self, @@ -701,7 +718,16 @@ def _forward_decode( context_lens=attn_metadata.decode.seq_lens, # type:ignore mla_vheadsize=self.kv_lora_rank, out=attn_output) - return self._v_up_proj_and_o_proj(attn_output) + from vllm_ascend.multistream.context import \ + get_multistream_comm_context + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + return self._v_up_proj_and_o_proj(attn_output) + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + return self._v_up_proj_and_o_proj(attn_output) def forward( self, @@ -827,23 +853,22 @@ def forward( key_cache=kv_cache, slot_indices=attn_metadata.slot_mapping.flatten()) if has_prefill: - # FIX: aicore move/copy should be also placed on the comm stream 
in dbo, - # otherwise it may affect the accuracy or disturb the overlap of next stage - # TODO: use an elegant way here to avoid it - from vllm_ascend.multistream.context import get_multistream_comm_context + # FIX: aicore move should be also placed on the comm stream in dbo, + # otherwise it may affect the accuracy + # TODO: use an elegant way to overlap + from vllm_ascend.multistream.context import \ + get_multistream_comm_context + output_prefill = self._forward_prefill(prefill_q, + prefill_k_c_normed, + prefill_k_pe, kv_cache, + attn_metadata) current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is None: - output[num_decode_tokens:] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) - else: - current_ms_metadata.before_comm_event.record() + if current_ms_metadata is not None: with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - output[num_decode_tokens:] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + output[num_decode_tokens:] = output_prefill current_ms_metadata.after_comm_event.record() + else: + output[num_decode_tokens:] = output_prefill if has_decode: if self.running_in_graph: @@ -851,20 +876,19 @@ def forward( decode_k_nope, decode_k_pe, kv_cache, attn_metadata) else: - - from vllm_ascend.multistream.context import get_multistream_comm_context + from vllm_ascend.multistream.context import \ + get_multistream_comm_context + output_decode = self._forward_decode(decode_ql_nope, + decode_q_pe, + decode_k_nope, + decode_k_pe, kv_cache, + attn_metadata) current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is None: - output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, decode_k_nope, - decode_k_pe, kv_cache, attn_metadata) - else: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, decode_k_nope, - decode_k_pe, kv_cache, attn_metadata) - current_ms_metadata.after_comm_event.record() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + output[:num_decode_tokens] = output_decode + current_ms_metadata.after_comm_event.record() + else: + output[:num_decode_tokens] = output_decode return output_padded diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 26247c0e61..5dd7c2a007 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -30,9 +30,23 @@ import torch import torch.distributed as dist import torch_npu -import vllm.envs as envs +import vllm_ascend.envs as envs_ascend from torch import nn from transformers import PretrainedConfig +from vllm_ascend.multistream.base import MSEventKey +from vllm_ascend.multistream.context import ( + advance_step_multistream_layer_context, get_multistream_comm_context, + get_multistream_layer_context, set_multistream_context) +from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, + MultiStreamPreTransformerLayer) +from vllm_ascend.multistream.metadata import (MultiStreamConfig, + MultiStreamStepMetadata, + make_multistream_metadata_ds) +from vllm_ascend.multistream.ms_split import compute_split_seq_index +from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.quantization.w8a8_dynamic import 
AscendW8A8DynamicLinearMethod + +import vllm.envs as envs from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, @@ -71,15 +85,17 @@ from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor -from vllm_ascend.multistream.context import (set_multistream_context,get_multistream_layer_context, - advance_step_multistream_layer_context, get_multistream_comm_context) -from vllm_ascend.multistream.layers import (MultiStreamPreTransformerLayer, MultiStreamPostTransformerLayer) +from vllm_ascend.multistream.context import ( + set_multistream_context, get_multistream_layer_context, + advance_step_multistream_layer_context, get_multistream_comm_context) +from vllm_ascend.multistream.layers import (MultiStreamPreTransformerLayer, + MultiStreamPostTransformerLayer) from vllm_ascend.multistream.metadata import make_multistream_metadata_ds, MultiStreamStepMetadata, MultiStreamConfig from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.ms_split import compute_split_seq_index VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 -VLLM_ENABLE_MS: bool = envs_ascend.VLLM_ENABLE_MS +VLLM_ENABLE_DBO: bool = envs_ascend.VLLM_ENABLE_DBO class CustomDeepseekV2MLP(nn.Module): @@ -151,6 +167,50 @@ def forward(self, x): x, _ = self.down_proj(x) return x + def _forward_ms_mlp(self, x): + current_ms_metadata = get_multistream_comm_context() + assert current_ms_metadata is not None + if self.is_dynamic_quant: + x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + x = torch_npu.npu_quant_matmul( + x, + self.gate_up_proj.weight, + self.gate_up_proj.weight_scale, + output_dtype=torch.int32, + ) + x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( + x=x, + weight_scale=self.gate_up_proj.weight_scale_fp32, + activation_scale=dynamic_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=None, + activate_left=True, + quant_mode=1) + x = torch_npu.npu_quant_matmul( + x, + self.down_proj.weight, + self.down_proj.weight_scale, + pertoken_scale=dynamic_scale, + output_dtype=torch.bfloat16, + ) + if self.down_proj.reduce_results and self.down_proj.tp_size > 1: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + x = tensor_model_parallel_all_reduce(x) + current_ms_metadata.after_comm_event.record() + return x + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + x, _ = self.down_proj(x) + current_ms_metadata.after_comm_event.record() + return x + class CustomDeepseekV2MoE(nn.Module): @@ -322,44 +382,58 @@ def forward( # ----------------------------------------- TBO-related -------------------------------------------- def _forward_ms_op_shared_expert( - self, - hidden_states: torch.Tensor, - ): - shared_output = self.shared_experts(hidden_states) + self, + hidden_states: torch.Tensor, + ): + shared_output = self.shared_experts._forward_ms_mlp(hidden_states) return shared_output - + def _forward_ms_op_gate( - self, - hidden_states: torch.Tensor, - ): + self, + hidden_states: torch.Tensor, + ): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) return router_logits - - def _forward_ms_op_tp_allreduce( - self, - 
hidden_states: torch.Tensor, - shared_output: torch.Tensor, - chunk_hidden_states: torch.Tensor, - num_tokens: int = 0, - hidden_dim: int = 0, - ): - + + def _forward_ms_op_tp_allgather( + self, + hidden_states: torch.Tensor, + shared_output: torch.Tensor, + chunk_hidden_states: torch.Tensor, + num_tokens: int = 0, + hidden_dim: int = 0, + ): + if self.tp_size > 1: - dist.all_gather(list(chunk_hidden_states), hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - #if num_tokens < self.tp_size: - # final_hidden_states = final_hidden_states[:num_tokens] - if num_tokens > 0: - final_hidden_states = final_hidden_states[:-num_tokens] + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + #if num_tokens < self.tp_size: + # final_hidden_states = final_hidden_states[:num_tokens] + if num_tokens > 0: + final_hidden_states = final_hidden_states[:-num_tokens] + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + #if num_tokens < self.tp_size: + # final_hidden_states = final_hidden_states[:num_tokens] + if num_tokens > 0: + final_hidden_states = final_hidden_states[:-num_tokens] + else: final_hidden_states = hidden_states if shared_output is not None: final_hidden_states = final_hidden_states + shared_output - final_hidden_states = final_hidden_states.view(num_tokens, hidden_dim) - + final_hidden_states = final_hidden_states.view( + num_tokens, hidden_dim) + return final_hidden_states @@ -661,18 +735,18 @@ def _forward_ms_layer( positions: List[torch.Tensor], hidden_states: List[torch.Tensor], residual: Optional[List[torch.Tensor]], + attn_metadata: List[AttentionMetadata], kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[List[AttentionMetadata]] = None, is_prefill: bool = False, - ) -> List[torch.Tensor]: - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() + ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) assert layer_index >= 0 and ms_metadata is not None num_micro_batchs = ms_metadata.ms_config.num_micro_batches assert isinstance(self.mlp, CustomDeepseekV2MoE) assert len(positions) == num_micro_batchs assert len(hidden_states) == num_micro_batchs - assert len(residual) == num_micro_batchs - assert len(attn_metadata) == num_micro_batchs + assert attn_metadata is not None num_tokens = [] hidden_dims = [] shared_outputs = [] @@ -687,38 +761,49 @@ def _forward_ms_layer( ''' for i in range(num_micro_batchs): # wait last layer moe finishing communication - ms_metadata.try_wait_event(layer_index-1, i, MSEventKey.FFN_AR_FINISH) + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.FFN_AR_FINISH) context = MultiStreamStepMetadata( - comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.ATTN_COM_FINISH], - after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.ATTN_AR_FINISH], - ) - + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.ATTN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + 
MSEventKey.ATTN_AR_FINISH], + ) + with set_multistream_context(context, i): - context = get_forward_context() - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() - context.attn_metadata = attn_metadata[i] + forward_context = get_forward_context() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + forward_context.attn_metadata = attn_metadata[i] # input layernorm - hidden_states[i], residual[i] = self._forward_ms_op_input_layernorm(hidden_states[i], residual[i]) - # attention and tp allreducea - hidden_states[i], residual[i] = self._forward_ms_op_attn(positions[i], hidden_states[i], residual[i], kv_cache, attn_metadata[i]) - + hidden_states[i], residual[ + i] = self._forward_ms_op_input_layernorm( + hidden_states[i], residual[i]) + # attention and tp allreduce + hidden_states[i], residual[i] = self._forward_ms_op_attn( + positions[i], hidden_states[i], residual[i], kv_cache, + attn_metadata[i]) ''' block 3 : shared experts if there is an allreduce ops in shared expert, we can overlap it with the computation of the shared expert for next microbatch or moe gating ''' for i in range(num_micro_batchs): + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.ATTN_AR_FINISH) context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMP_FINISH], - after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH], + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_SE_COMP_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_SE_COMM_FINISH], ) with set_multistream_context(context, i): # compute shared expert after finishing ATTN AR - ms_metadata.try_wait_event(layer_index, i, MSEventKey.ATTN_AR_FINISH) - hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm(hidden_states[i], residual[i]) - + hidden_states[i], residual[ + i] = self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i]) num_token, hidden_dim = hidden_states[i].shape hidden_states[i] = hidden_states[i].view(-1, hidden_dim) @@ -726,9 +811,10 @@ def _forward_ms_layer( hidden_dims.append(hidden_dim) if self.mlp.n_shared_experts is not None: # TODO: we can move shared expert computation into next block if reduce results is false - shared_output = self.mlp._forward_ms_op_shared_expert(hidden_states[i]) + shared_output = self.mlp._forward_ms_op_shared_expert( + hidden_states[i]) shared_outputs.append(shared_output) - + # block 4 : moe for i in range(num_micro_batchs): #ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH) @@ -752,11 +838,11 @@ def _forward_ms_layer( # new_hidden_states[:num_tokens[i]] = hidden_states[i] # hidden_states[i] = new_hidden_states num_token, _ = hidden_states[i].shape - padded_num_tokens = (self.mlp.tp_size - - num_token % self.mlp.tp_size) % self.mlp.tp_size + padded_num_tokens = (self.mlp.tp_size - num_token % + self.mlp.tp_size) % self.mlp.tp_size if padded_num_tokens > 0: - hidden_states[i] = nn.functional.pad(hidden_states[i], - (0, 0, 0, padded_num_tokens)) + hidden_states[i] = nn.functional.pad( + hidden_states[i], (0, 0, 0, padded_num_tokens)) num_tokens.append(padded_num_tokens) chunk_hidden_state = torch.tensor_split(hidden_states[i], self.mlp.tp_size, @@ -769,7 +855,6 @@ def _forward_ms_layer( router_logit = self.mlp._forward_ms_op_gate(local_hidden_states) router_logits.append(router_logit) - if CustomDeepseekV2MoE.top_k: 
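# real_top_k is the number of experts routed per token for this micro-batch:
# it is taken from the CustomDeepseekV2MoE class attribute when that has been
# set, otherwise from the fallback branch (elided in this hunk), and is handed
# to _forward_ms_fused_moe_comp below for every micro-batch.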
real_top_k = CustomDeepseekV2MoE.top_k else: @@ -778,50 +863,65 @@ def _forward_ms_layer( if VLLM_ENABLE_MC2 and not is_prefill: ... - hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(local_hidden_states, router_logits[i], is_prefill, real_top_k, enable_force_load_balance) + hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp( + local_hidden_states, router_logits[i], is_prefill, real_top_k, + enable_force_load_balance) if VLLM_ENABLE_MC2 and not is_prefill: ... - ''' the following kernels will be submitted to the comm stream to overlap the computation of the moe computation of next microbatch and the attn computation of next layer ''' context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_COM_FINISH], - after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_AFTER_COMM], + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], ) - with set_multistream_context(context, i): - if self.mlp.experts.reduce_results and (self.mlp.experts.tp_size > 1 or self.mlp.experts.ep_size > 1): + context.before_comm_event.record() + with torch.npu.stream(ms_metadata.communicate_stream): + #with set_multistream_context(context, i): + context.before_comm_event.wait() + if self.mlp.experts.reduce_results and ( + self.mlp.experts.tp_size > 1 + or self.mlp.experts.ep_size > 1): hidden_states[i] = tensor_model_parallel_all_reduce( hidden_states[i]) + context.after_comm_event.record() # check here - hidden_states[i] = hidden_states[i] * self.mlp.routed_scaling_factor + hidden_states[ + i] = hidden_states[i] * self.mlp.routed_scaling_factor context = MultiStreamStepMetadata( - comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_AFTER_COMM], - after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_AR_FINISH], - ) + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_AR_FINISH], + ) with set_multistream_context(context, i): - hidden_states[i] = self.mlp._forward_ms_op_tp_allreduce(hidden_states[i], shared_outputs[i], chunk_hidden_states[i], num_tokens[i], hidden_dims[i]) + hidden_states[i] = self.mlp._forward_ms_op_tp_allgather( + hidden_states[i], shared_outputs[i], + chunk_hidden_states[i], num_tokens[i], hidden_dims[i]) with torch.npu.stream(ms_metadata.communicate_stream): # last - if isinstance( - self.mlp, - CustomDeepseekV2MLP) and hidden_states[i].dtype == torch.float16: - # Fix FP16 overflow - # Scaling the DeepseekV2MLP output, it is the input of - # input_layernorm of next decoder layer. - # The scaling of DeepseekV2MOE output would be done in the forward - # of DeepseekV2MOE + if isinstance(self.mlp, CustomDeepseekV2MLP + ) and hidden_states[i].dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE hidden_states[i] *= 1. 
/ self.routed_scaling_factor + context.after_comm_event.record() return hidden_states, residual + # should split ops in Decoder Layer def _forward_ms_op_input_layernorm( - self, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ): + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -829,14 +929,15 @@ def _forward_ms_op_input_layernorm( hidden_states, residual = self.input_layernorm( hidden_states, residual) return hidden_states, residual + def _forward_ms_op_attn( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - ): + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -853,18 +954,15 @@ def _forward_ms_op_attn( # first layer. residual *= 1. / self.routed_scaling_factor return hidden_states, residual - + def _forward_ms_op_post_attn_layernorm( - self, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ): - hidden_states, residual = self.post_attention_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ): + hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - return hidden_states, residual - - - + return hidden_states, residual class CustomDeepseekV2Model(nn.Module): @@ -910,9 +1008,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - + # tbo related members - if VLLM_ENABLE_MS: + if VLLM_ENABLE_DBO: self.multistream_config = MultiStreamConfig() else: self.multistream_config = None @@ -923,8 +1021,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): causal_lm=getattr(config, "causal_lm", True), multistream_config=self.multistream_config, ) - self.ms_pre_layer = MultiStreamPreTransformerLayer(self.multistream_metadata) - self.ms_post_layer = MultiStreamPostTransformerLayer(self.multistream_metadata) + self.ms_pre_layer = MultiStreamPreTransformerLayer( + self.multistream_metadata) + self.ms_post_layer = MultiStreamPostTransformerLayer( + self.multistream_metadata) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -949,11 +1049,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - num_normal_layers = ( - self.first_k_dense_replace - if self.multistream_config is not None and self.can_run_ms() - else self.end_layer - self.start_layer - ) + num_normal_layers = (self.first_k_dense_replace + if self.multistream_config is not None + and self.can_run_ms() else self.end_layer - + self.start_layer) # if we enable multistream/dbo, only process dense layers here for i in range(self.start_layer, self.start_layer + num_normal_layers): layer = self.layers[i] @@ -962,14 +1061,13 @@ def forward( kv_caches[i - self.start_layer] if kv_caches is not None else None, attn_metadata) - + moe_start_layer = self.start_layer + num_normal_layers 
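# Rough shape of the dual-batch overlap (DBO) path that starts here (an
# illustrative sketch of what this patch wires up, not additional code):
#   attn_metadata, [positions, hidden_states, residual] = ms_pre_layer(...)
#       -> splits the current batch into two micro-batches
#   for each remaining MoE layer:
#       layer._forward_ms_layer(...) alternates between the micro-batches,
#       moving attention all-reduce, shared-expert communication and the MoE
#       all-gather/all-reduce onto ms_metadata.communicate_stream so they
#       overlap with the other micro-batch's computation
#   hidden_states, residual = ms_post_layer(...)
#       -> waits on the last layer's FFN_AR_FINISH event and merges the
#          micro-batches back into a single batch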
hidden_states, residual = self._forward_ms_layers( positions=positions, hidden_states=hidden_states, residual=residual, moe_start_layer=moe_start_layer, - attn_metadata=attn_metadata, kv_caches=kv_caches, ) @@ -985,7 +1083,7 @@ def forward( def can_run_ms(self): # currently we only enable prefill overlap attn_metadata = get_forward_context().attn_metadata - dp_metadata = get_forward_context().dp_metadata + # dp_metadata = get_forward_context().dp_metadata # profile run if self.multistream_config is None or attn_metadata is None: return False @@ -995,57 +1093,61 @@ def can_run_ms(self): # disable decode dbo if attn_metadata.num_prefills == 0: return False - num_microbatchs = self.multistream_config.num_micro_batches # check whether there is a dp rank that not use dual batch - '''if dp_metadata is not None: + ''' + num_microbatchs = self.multistream_config.num_micro_batches + if dp_metadata is not None: for i in range(num_microbatchs): cu_tokens = dp_metadata.cu_dbo_tokens_across_dp_cpu[i] if torch.any(cu_tokens == 0).item(): return False - [token_index, seq_index] = compute_split_seq_index(attn_metadata.query_lens, - attn_metadata.attn_state, attn_metadata.num_decode_tokens) - ''' - if token_index == 0 or seq_index == 0 or seq_index == len(attn_metadata.query_lens): + ''' + [token_index, + seq_index] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, + attn_metadata.num_decode_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + attn_metadata.query_lens): return False # check whether the total tokens exceed the threshold if attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: return False return True + def _forward_ms_layers( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor, - moe_start_layer: int, - attn_metadata:Optional[AttentionMetadata] = None, - kv_caches: Optional[List[torch.Tensor]] = None, - is_prefill: bool = False, + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + moe_start_layer: int, + kv_caches: Optional[List[torch.Tensor]] = None, + is_prefill: bool = False, ): - + if moe_start_layer == self.end_layer: return hidden_states, residual - - attn_metadata, [positions, hidden_states, residual] = self.ms_pre_layer( - [positions, hidden_states, residual], - ) + + attn_metadata, [positions, hidden_states, + residual] = self.ms_pre_layer( + [positions, hidden_states, residual], ) # the rest layers for i in range(moe_start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer._forward_ms_layer( - positions=positions, - hidden_states=hidden_states, + positions=positions, + hidden_states=hidden_states, residual=residual, - kv_cache=kv_caches[i - - self.start_layer] if kv_caches is not None else None, - attn_metadata=attn_metadata, + attn_metadata=attn_metadata, + kv_cache=kv_caches[i - self.start_layer] + if kv_caches is not None else None, is_prefill=is_prefill) advance_step_multistream_layer_context() - - [hidden_states, residual] = self.ms_post_layer( - [hidden_states, residual], - ) + + [hidden_states, + residual] = self.ms_post_layer([hidden_states, residual], ) return hidden_states, residual + class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging packed_modules_mapping = { diff --git a/vllm_ascend/multistream/base.py b/vllm_ascend/multistream/base.py index 5eb89e6dd1..f407fd2f76 100644 --- 
a/vllm_ascend/multistream/base.py +++ b/vllm_ascend/multistream/base.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from enum import Enum + # TODO: move this part to vllm class MSEventKey(Enum): ATTN_COM_FINISH = 0 @@ -22,6 +23,6 @@ class MSAttentionMetadataSplitConfig: # micro batch num num_micro_batches: int = 2 # split micro batches only when total tokens >= min_total_tokens_to_split - min_total_tokens_to_split: int = 256, + min_total_tokens_to_split: int = 256 # split micro batches only when prefill tokens >= min_prefill_tokens_to_split - min_prefill_tokens_to_split: int = 64, \ No newline at end of file + min_prefill_tokens_to_split: int = 64 \ No newline at end of file diff --git a/vllm_ascend/multistream/decorator.py b/vllm_ascend/multistream/decorator.py index 705a0bc697..3c307b0fb4 100644 --- a/vllm_ascend/multistream/decorator.py +++ b/vllm_ascend/multistream/decorator.py @@ -7,7 +7,7 @@ logger = init_logger(__name__) # vllm v1 use get_forward_context to get the attn_metadata, -# we update it to the splitted version if enable dbo +# we can use this decorator to update the attn metadata def set_multistream_support(): def decorator(func): def wrapper(): diff --git a/vllm_ascend/multistream/layers.py b/vllm_ascend/multistream/layers.py index 25468f4312..374e55a146 100644 --- a/vllm_ascend/multistream/layers.py +++ b/vllm_ascend/multistream/layers.py @@ -1,10 +1,15 @@ +from typing import List, Tuple, Union, Optional + import torch -from typing import List, Union, Tuple +from vllm_ascend.multistream.context import (get_multistream_layer_context, + reset_multistream_layer_context, + set_multistream_layer_context) + from vllm.forward_context import get_forward_context + from .base import MSEventKey from .metadata import MultiStreamMetadata -from vllm_ascend.multistream.context import (set_multistream_layer_context, reset_multistream_layer_context, - get_multistream_layer_context) + # TODO: move this part to vllm class MultiStreamPreTransformerLayer(torch.nn.Module): @@ -29,8 +34,8 @@ def __init__(self, multistream_metadata: MultiStreamMetadata): super().__init__() self.multistream_metadata = multistream_metadata def forward(self, input_tensors: Union[List[Tuple[torch.Tensor]], List[torch.Tensor], List[List[torch.Tensor]]], - wait_layer_index: int = None): - if self.multistream_metadata is None: + wait_layer_index: Optional[int] = None): + if self.multistream_metadata is None or self.multistream_metadata.ms_config is None: return input_tensors layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context() if layer_index >= 0: diff --git a/vllm_ascend/multistream/metadata.py b/vllm_ascend/multistream/metadata.py index 79ed4e33fa..d566a48ff9 100644 --- a/vllm_ascend/multistream/metadata.py +++ b/vllm_ascend/multistream/metadata.py @@ -1,12 +1,16 @@ from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + import torch -from typing import Dict, List, Optional, Union, Tuple + +from vllm.attention.backends.abstract import AttentionMetadata from vllm.sequence import IntermediateTensors + from .base import MSAttentionMetadataSplitConfig, MSEventKey -from vllm.attention.backends.abstract import AttentionMetadata + # TODO: move this part to vllm -def split_micro_batches_tensors(input_tensors, split_index: int, keys: List[str] = None): +def split_micro_batches_tensors(input_tensors, split_index: int, keys: Optional[List[str]] = None): if isinstance(input_tensors, list): micro_batches = [] for tensor in input_tensors: @@ -70,8 +74,10 @@ def 
__init__(self, self.end_layer = end_layer self.ms_config = multistream_config self.causal_lm = causal_lm - self._build_events(event_keys) - self._build_ms_split_config() + if self.ms_config is not None: + self._build_events(event_keys) + self._build_ms_split_config() + def _build_events(self, event_keys): for i in range(self.start_layer - 1, self.end_layer): self.ms_events[i] = {} @@ -114,7 +120,7 @@ def merge_micro_batches(self, ) -> List[torch.Tensor]: if input_tensors is None or isinstance(input_tensors[0], torch.Tensor): return input_tensors - batch = [] + batch: List[Optional[torch.Tensor]] = [] for tensors in input_tensors: if tensors is None or tensors[0] is None: batch.append(None) diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 636bc86e1d..63cfb71701 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -1,8 +1,11 @@ -import torch -from typing import Optional, Any, List +from typing import Any, List, Optional + import numpy as np +import torch from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig + + def compute_split_seq_index( query_lens: Optional[list[int]], attn_state: AscendAttentionState, @@ -31,6 +34,7 @@ def compute_split_seq_index( else: tokens = num_tokens // 2 return [tokens, tokens] + return [0, 0] def split_attn_tensor_type( input_tensor: torch.Tensor, @@ -49,6 +53,7 @@ def model_input_split_v1_mla_attn( ms_split_config: MSAttentionMetadataSplitConfig, ) -> List[Any]: assert 0 < ms_split_config.num_micro_batches < 3 + assert attn_metadata is not None [token_index, seq_index] = compute_split_seq_index(attn_metadata.query_lens, attn_metadata.attn_state, attn_metadata.num_decode_tokens) if token_index == 0 or seq_index == 0 or seq_index == len(attn_metadata.query_lens): @@ -88,7 +93,8 @@ def model_input_split_v1_mla_attn( attn_mask_pre = None attn_state_post = AscendAttentionState.ChunkedPrefill attn_mask_post = attn_metadata.attn_mask[token_index:, :max(seq_lens_post)].contiguous() - from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata, AscendMLADecodeMetadata + from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, + AscendMLAPrefillMetadata) if num_prefills_pre > 0: # split metadata.prefill [input_positions_pre, input_positions_post] = split_attn_tensor_type(attn_metadata.prefill.input_positions, token_index - attn_metadata.num_decode_tokens) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index c9cd9e7f82..f286f83c23 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -31,8 +31,6 @@ from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig -from vllm_ascend.multistream.base import MSEventKey -from vllm_ascend.multistream.metadata import MultiStreamStepMetadata, MultiStreamMetadata import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group From 423e9202199f94646d1c882c00842d74d40ba098 Mon Sep 17 00:00:00 2001 From: zhuohuan Date: Wed, 28 May 2025 20:50:30 +0800 Subject: [PATCH 04/11] [fix]: resolve format issues Signed-off-by: zhuohuan --- vllm_ascend/attention/mla_v1.py | 10 +- vllm_ascend/models/deepseek_v2.py | 77 +++------ vllm_ascend/multistream/base.py | 6 +- vllm_ascend/multistream/context.py | 9 +- vllm_ascend/multistream/decorator.py | 15 +- vllm_ascend/multistream/layers.py 
| 47 ++++-- vllm_ascend/multistream/metadata.py | 126 ++++++++------ vllm_ascend/multistream/ms_split.py | 240 +++++++++++++++------------ 8 files changed, 295 insertions(+), 235 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index e78ef4ca5a..40c7101f09 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -4,12 +4,6 @@ import numpy as np import torch import torch_npu -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig -from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn -from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla -from vllm_ascend.worker.model_runner_v1 import NPUModelRunner - from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) @@ -19,10 +13,10 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla - from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn +from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 5dd7c2a007..70add15059 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -30,23 +30,9 @@ import torch import torch.distributed as dist import torch_npu -import vllm_ascend.envs as envs_ascend +import vllm.envs as envs from torch import nn from transformers import PretrainedConfig -from vllm_ascend.multistream.base import MSEventKey -from vllm_ascend.multistream.context import ( - advance_step_multistream_layer_context, get_multistream_comm_context, - get_multistream_layer_context, set_multistream_context) -from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, - MultiStreamPreTransformerLayer) -from vllm_ascend.multistream.metadata import (MultiStreamConfig, - MultiStreamStepMetadata, - make_multistream_metadata_ds) -from vllm_ascend.multistream.ms_split import compute_split_seq_index -from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod - -import vllm.envs as envs from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, @@ -80,19 +66,18 @@ from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod -from vllm_ascend.utils import dispose_tensor - -from vllm_ascend.multistream.context import ( - set_multistream_context, get_multistream_layer_context, - advance_step_multistream_layer_context, get_multistream_comm_context) -from vllm_ascend.multistream.layers import (MultiStreamPreTransformerLayer, - MultiStreamPostTransformerLayer) -from vllm_ascend.multistream.metadata import make_multistream_metadata_ds, MultiStreamStepMetadata, MultiStreamConfig from vllm_ascend.multistream.base import MSEventKey 
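# The multistream imports gathered here provide the DBO building blocks used
# later in this file: the context helpers carry per-layer / per-micro-batch
# step metadata, MultiStreamPre/PostTransformerLayer split the batch before the
# MoE layers and merge it afterwards, MultiStreamConfig / MultiStreamStepMetadata /
# make_multistream_metadata_ds handle stream and event bookkeeping, and
# compute_split_seq_index decides where a batch can be cut into micro-batches.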
+from vllm_ascend.multistream.context import ( + advance_step_multistream_layer_context, get_multistream_comm_context, + get_multistream_layer_context, set_multistream_context) +from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, + MultiStreamPreTransformerLayer) +from vllm_ascend.multistream.metadata import (MultiStreamConfig, + MultiStreamStepMetadata, + make_multistream_metadata_ds) from vllm_ascend.multistream.ms_split import compute_split_seq_index +from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 VLLM_ENABLE_DBO: bool = envs_ascend.VLLM_ENABLE_DBO @@ -734,7 +719,7 @@ def _forward_ms_layer( self, positions: List[torch.Tensor], hidden_states: List[torch.Tensor], - residual: Optional[List[torch.Tensor]], + residual: List[torch.Tensor], attn_metadata: List[AttentionMetadata], kv_cache: Optional[torch.Tensor] = None, is_prefill: bool = False, @@ -746,6 +731,7 @@ def _forward_ms_layer( assert isinstance(self.mlp, CustomDeepseekV2MoE) assert len(positions) == num_micro_batchs assert len(hidden_states) == num_micro_batchs + assert residual is not None assert attn_metadata is not None num_tokens = [] hidden_dims = [] @@ -1010,10 +996,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ["hidden_states", "residual"], config.hidden_size)) # tbo related members + self.multistream_config: Optional[MultiStreamConfig] = None if VLLM_ENABLE_DBO: self.multistream_config = MultiStreamConfig() - else: - self.multistream_config = None + self.use_mla = model_config.use_mla self.multistream_metadata = make_multistream_metadata_ds( start_layer=self.start_layer + self.first_k_dense_replace, @@ -1083,32 +1069,23 @@ def forward( def can_run_ms(self): # currently we only enable prefill overlap attn_metadata = get_forward_context().attn_metadata - # dp_metadata = get_forward_context().dp_metadata # profile run - if self.multistream_config is None or attn_metadata is None: + if attn_metadata is None or attn_metadata.num_prefills == 0: + return False + else: + [token_index, seq_index + ] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, + attn_metadata.num_decode_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + attn_metadata.query_lens): + return False + + if self.multistream_config is None: return False # support mla attention and V1 engine at present if not self.use_mla or not envs.VLLM_USE_V1: return False - # disable decode dbo - if attn_metadata.num_prefills == 0: - return False - # check whether there is a dp rank that not use dual batch - ''' - num_microbatchs = self.multistream_config.num_micro_batches - if dp_metadata is not None: - for i in range(num_microbatchs): - cu_tokens = dp_metadata.cu_dbo_tokens_across_dp_cpu[i] - if torch.any(cu_tokens == 0).item(): - return False - ''' - [token_index, - seq_index] = compute_split_seq_index(attn_metadata.query_lens, - attn_metadata.attn_state, - attn_metadata.num_decode_tokens) - if token_index == 0 or seq_index == 0 or seq_index == len( - attn_metadata.query_lens): - return False # check whether the total tokens exceed the threshold if attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: return False diff --git a/vllm_ascend/multistream/base.py b/vllm_ascend/multistream/base.py index f407fd2f76..9e3dd4b443 100644 --- a/vllm_ascend/multistream/base.py +++ b/vllm_ascend/multistream/base.py @@ -11,10 
+11,12 @@ class MSEventKey(Enum): # events for MOE dispatch and combine MOE_BEFORE_COMM = 4 MOE_AFTER_COMM = 5 - # events for shared expert + # events for shared expert MOE_SE_COMM_FINISH = 6 MOE_SE_COMP_FINISH = 7 MOE_GATE_FINISH = 8 + + @dataclass class MSAttentionMetadataSplitConfig: """ @@ -25,4 +27,4 @@ class MSAttentionMetadataSplitConfig: # split micro batches only when total tokens >= min_total_tokens_to_split min_total_tokens_to_split: int = 256 # split micro batches only when prefill tokens >= min_prefill_tokens_to_split - min_prefill_tokens_to_split: int = 64 \ No newline at end of file + min_prefill_tokens_to_split: int = 64 diff --git a/vllm_ascend/multistream/context.py b/vllm_ascend/multistream/context.py index 21c9001464..9cda0649a8 100644 --- a/vllm_ascend/multistream/context.py +++ b/vllm_ascend/multistream/context.py @@ -9,7 +9,9 @@ _ms_metadata_context: Any = None _ms_attn_metadata_context: Any = None -def set_multistream_layer_context(start_layer: int, ms_metadata: Any, attn_metadata: Any): + +def set_multistream_layer_context(start_layer: int, ms_metadata: Any, + attn_metadata: Any): """ set multistream layer context before transformer layers """ @@ -18,6 +20,7 @@ def set_multistream_layer_context(start_layer: int, ms_metadata: Any, attn_metad _ms_metadata_context = ms_metadata _ms_attn_metadata_context = attn_metadata + def reset_multistream_layer_context(): """ reset multistream layer context @@ -27,12 +30,14 @@ def reset_multistream_layer_context(): _ms_metadata_context = None _ms_attn_metadata_context = None + def get_multistream_layer_context(): """ get multistream layer context """ return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context + def advance_step_multistream_layer_context(): """ advance multistream layer index context @@ -45,9 +50,11 @@ def get_multistream_comm_context() -> Any: """Get the current comm forward context.""" return _ms_comm_context + def get_multistream_microbatch_context() -> int: return _cur_micro_batch_num + @contextmanager def set_multistream_context(context: Any, micro_batch_num: int): """A context manager that stores the current comm forward context, diff --git a/vllm_ascend/multistream/decorator.py b/vllm_ascend/multistream/decorator.py index 3c307b0fb4..381634a450 100644 --- a/vllm_ascend/multistream/decorator.py +++ b/vllm_ascend/multistream/decorator.py @@ -1,21 +1,28 @@ +from vllm.logger import init_logger + from .context import (get_multistream_layer_context, get_multistream_microbatch_context) -from vllm.logger import init_logger # TODO: move this part to vllm logger = init_logger(__name__) -# vllm v1 use get_forward_context to get the attn_metadata, + +# vllm v1 use get_forward_context to get the attn_metadata, # we can use this decorator to update the attn metadata def set_multistream_support(): + def decorator(func): + def wrapper(): context = func() - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) micro_batch_num = get_multistream_microbatch_context() if layer_index != -1 and micro_batch_num != -1: context.attn_metadata = attn_metadata[micro_batch_num] return context + return wrapper - return decorator \ No newline at end of file + + return decorator diff --git a/vllm_ascend/multistream/layers.py b/vllm_ascend/multistream/layers.py index 374e55a146..b940d09bd9 100644 --- a/vllm_ascend/multistream/layers.py +++ b/vllm_ascend/multistream/layers.py @@ -1,47 +1,62 @@ -from typing import List, Tuple, 
Union, Optional +from typing import List, Optional, Tuple, Union import torch -from vllm_ascend.multistream.context import (get_multistream_layer_context, - reset_multistream_layer_context, - set_multistream_layer_context) - from vllm.forward_context import get_forward_context from .base import MSEventKey +from .context import (get_multistream_layer_context, + reset_multistream_layer_context, + set_multistream_layer_context) from .metadata import MultiStreamMetadata # TODO: move this part to vllm class MultiStreamPreTransformerLayer(torch.nn.Module): + def __init__(self, multistream_metadata: MultiStreamMetadata): super().__init__() self.multistream_metadata = multistream_metadata - def forward(self, - intput_tensors: List[torch.Tensor],): + + def forward( + self, + intput_tensors: List[torch.Tensor], + ): attn_metadata = get_forward_context().attn_metadata if self.multistream_metadata is None or attn_metadata is None: set_multistream_layer_context(-1, None, None) return attn_metadata, intput_tensors # TODO add attn_metadata management - do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch(attn_metadata, intput_tensors) + do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch( + attn_metadata, intput_tensors) if do_ms: - set_multistream_layer_context(self.multistream_metadata.start_layer, self.multistream_metadata, attn_metadata) + set_multistream_layer_context( + self.multistream_metadata.start_layer, + self.multistream_metadata, attn_metadata) else: set_multistream_layer_context(-1, None, None) return attn_metadata, intput_tensors + + class MultiStreamPostTransformerLayer(torch.nn.Module): + def __init__(self, multistream_metadata: MultiStreamMetadata): super().__init__() self.multistream_metadata = multistream_metadata - def forward(self, input_tensors: Union[List[Tuple[torch.Tensor]], List[torch.Tensor], List[List[torch.Tensor]]], + + def forward(self, + input_tensors: Union[List[Tuple[torch.Tensor]], + List[torch.Tensor], + List[List[torch.Tensor]]], wait_layer_index: Optional[int] = None): if self.multistream_metadata is None or self.multistream_metadata.ms_config is None: return input_tensors - layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context() + layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context( + ) if layer_index >= 0: - true_wait_layer = self.multistream_metadata.end_layer-1 if wait_layer_index is None else wait_layer_index - self.multistream_metadata.try_wait_event(true_wait_layer, - self.multistream_metadata.ms_config.num_micro_batches-1, - MSEventKey.FFN_AR_FINISH) + true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index + self.multistream_metadata.try_wait_event( + true_wait_layer, + self.multistream_metadata.ms_config.num_micro_batches - 1, + MSEventKey.FFN_AR_FINISH) reset_multistream_layer_context() - return self.multistream_metadata.merge_micro_batches(input_tensors) \ No newline at end of file + return self.multistream_metadata.merge_micro_batches(input_tensors) diff --git a/vllm_ascend/multistream/metadata.py b/vllm_ascend/multistream/metadata.py index d566a48ff9..9a721cfe16 100644 --- a/vllm_ascend/multistream/metadata.py +++ b/vllm_ascend/multistream/metadata.py @@ -2,22 +2,23 @@ from typing import Dict, List, Optional, Tuple, Union import torch - from vllm.attention.backends.abstract import AttentionMetadata from vllm.sequence import IntermediateTensors from .base import MSAttentionMetadataSplitConfig, 
MSEventKey -# TODO: move this part to vllm -def split_micro_batches_tensors(input_tensors, split_index: int, keys: Optional[List[str]] = None): +def split_micro_batches_tensors(input_tensors, + split_index: int, + keys: Optional[List[str]] = None): if isinstance(input_tensors, list): micro_batches = [] for tensor in input_tensors: if tensor is None: micro_batches.append([None, None]) else: - micro_batches.append([tensor[:split_index], tensor[split_index:]]) + micro_batches.append( + [tensor[:split_index], tensor[split_index:]]) return micro_batches elif isinstance(input_tensors, torch.Tensor): return [input_tensors[:split_index], input_tensors[split_index:]] @@ -35,12 +36,14 @@ def split_micro_batches_tensors(input_tensors, split_index: int, keys: Optional[ else: raise NotImplementedError + @dataclass class MultiStreamStepMetadata: comm_stream: torch.npu.Stream = None before_comm_event: torch.npu.Event = None after_comm_event: torch.npu.Event = None + @dataclass class MultiStreamConfig: """Controls the behavior of multi-stream models.""" @@ -49,6 +52,7 @@ class MultiStreamConfig: num_micro_batches: int = 2 imbalance_ratio: float = 0.1 + class MultiStreamMetadata: # direct stream calculate_stream = None @@ -58,66 +62,84 @@ class MultiStreamMetadata: ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {} # multi-stream-flag enable_multi_stream: bool = False - - def __init__(self, - calculate_stream: torch.npu.Stream, - communicate_stream: torch.npu.Stream, - start_layer: int, - end_layer: int, - event_keys: List[MSEventKey], - multistream_config: Optional[MultiStreamConfig], - causal_lm: bool = True, - ): + + def __init__( + self, + calculate_stream: torch.npu.Stream, + communicate_stream: torch.npu.Stream, + start_layer: int, + end_layer: int, + event_keys: List[MSEventKey], + multistream_config: Optional[MultiStreamConfig], + causal_lm: bool = True, + ): self.calculate_stream = calculate_stream self.communicate_stream = communicate_stream self.start_layer = start_layer self.end_layer = end_layer self.ms_config = multistream_config self.causal_lm = causal_lm - if self.ms_config is not None: - self._build_events(event_keys) - self._build_ms_split_config() + self._build_events(event_keys) + self._build_ms_split_config() def _build_events(self, event_keys): - for i in range(self.start_layer - 1, self.end_layer): - self.ms_events[i] = {} - for j in range(self.ms_config.num_micro_batches): - self.ms_events[i][j] = {} - for key in event_keys: - self.ms_events[i][j][key] = torch.npu.Event() + if self.ms_config is not None: + for i in range(self.start_layer - 1, self.end_layer): + self.ms_events[i] = {} + for j in range(self.ms_config.num_micro_batches): + self.ms_events[i][j] = {} + for key in event_keys: + self.ms_events[i][j][key] = torch.npu.Event() + def _build_ms_split_config(self): - self.ms_split_config = MSAttentionMetadataSplitConfig( - num_micro_batches=self.ms_config.num_micro_batches, - min_total_tokens_to_split=self.ms_config.min_total_tokens_to_split, - min_prefill_tokens_to_split=self.ms_config.min_prefill_tokens_to_split, - ) - def try_wait_event(self, layer_index: int, micro_batch_index: int, event_key: MSEventKey): + if self.ms_config is not None: + self.ms_split_config = MSAttentionMetadataSplitConfig( + num_micro_batches=self.ms_config.num_micro_batches, + min_total_tokens_to_split=self.ms_config. + min_total_tokens_to_split, + min_prefill_tokens_to_split=self.ms_config. 
+ min_prefill_tokens_to_split, + ) + + def try_wait_event(self, layer_index: int, micro_batch_index: int, + event_key: MSEventKey): self.ms_events[layer_index][micro_batch_index][event_key].wait() - def try_record_event(self, layer_index: int, micro_batch_index: int, event_key: MSEventKey): + + def try_record_event(self, layer_index: int, micro_batch_index: int, + event_key: MSEventKey): self.ms_events[layer_index][micro_batch_index][event_key].record() - def split_micro_batch(self, - attn_metadata: "AttentionMetadata", - intput_tensors: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - intermediate_tensors_keys: Optional[List[str]] = None, - ) -> Tuple[bool, - Union[AttentionMetadata, List[AttentionMetadata]], - Union[List[torch.Tensor], List[List[torch.Tensor]]], - Union[IntermediateTensors, List[IntermediateTensors]]]: - attn_metadata = attn_metadata.split_metadata_for_multistream(self.ms_split_config) + + def split_micro_batch( + self, + attn_metadata: "AttentionMetadata", + intput_tensors: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors_keys: Optional[List[str]] = None, + ) -> Tuple[bool, Union[AttentionMetadata, List[AttentionMetadata]], Union[ + List[torch.Tensor], List[List[torch.Tensor]]], Union[ + IntermediateTensors, List[IntermediateTensors]]]: + attn_metadata = attn_metadata.split_metadata_for_multistream( + self.ms_split_config) if len(attn_metadata) == 1: - return False, attn_metadata[0], intput_tensors, intermediate_tensors + return False, attn_metadata[ + 0], intput_tensors, intermediate_tensors split_index = attn_metadata[0].slot_mapping.shape[0] - input_tensors = split_micro_batches_tensors(intput_tensors, split_index) + input_tensors = split_micro_batches_tensors(intput_tensors, + split_index) if intermediate_tensors is not None: - inter_tensors_list = split_micro_batches_tensors(intermediate_tensors.tensors, split_index, intermediate_tensors_keys) + inter_tensors_list = split_micro_batches_tensors( + intermediate_tensors.tensors, split_index, + intermediate_tensors_keys) intermediate_tensors = [ - IntermediateTensors(inter_tensors) for inter_tensors in inter_tensors_list + IntermediateTensors(inter_tensors) + for inter_tensors in inter_tensors_list ] return True, attn_metadata, input_tensors, intermediate_tensors - def merge_micro_batches(self, - input_tensors: Union[List[torch.Tensor], List[List[torch.Tensor]]] - ) -> List[torch.Tensor]: + + def merge_micro_batches( + self, input_tensors: Union[List[torch.Tensor], + List[List[torch.Tensor]]] + ) -> List[torch.Tensor]: if input_tensors is None or isinstance(input_tensors[0], torch.Tensor): return input_tensors batch: List[Optional[torch.Tensor]] = [] @@ -130,10 +152,10 @@ def merge_micro_batches(self, def make_multistream_metadata_ds( - start_layer: int, - end_layer: int, - causal_lm: bool = True, - multistream_config: Optional[MultiStreamConfig] = None, + start_layer: int, + end_layer: int, + causal_lm: bool = True, + multistream_config: Optional[MultiStreamConfig] = None, ): if multistream_config is None: return None @@ -156,4 +178,4 @@ def make_multistream_metadata_ds( multistream_config=multistream_config, event_keys=event_keylist, causal_lm=causal_lm, - ) \ No newline at end of file + ) diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 63cfb71701..487fa151c9 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -2,16 +2,18 @@ import numpy as 
np import torch + from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig + +from .base import MSAttentionMetadataSplitConfig def compute_split_seq_index( - query_lens: Optional[list[int]], - attn_state: AscendAttentionState, - num_tokens: int, - imbalance_ratio: float = 0.1, - )->Optional[list[int]]: + query_lens: Optional[list[int]], + attn_state: AscendAttentionState, + num_tokens: int, + imbalance_ratio: float = 0.1, +) -> list[int]: if attn_state != AscendAttentionState.DecodeOnly: assert query_lens is not None total_tokens = sum(query_lens) @@ -20,59 +22,79 @@ def compute_split_seq_index( for value in query_lens: tokens += value split_index += 1 - if tokens >= total_tokens // 2 : + if tokens >= total_tokens // 2: # check the current split index - if abs(tokens - total_tokens // 2) < total_tokens * imbalance_ratio: + if abs(tokens - + total_tokens // 2) < total_tokens * imbalance_ratio: return [tokens, split_index] # check the previous split index - elif abs(tokens - total_tokens // 2 - value) < total_tokens * imbalance_ratio: - return [tokens-value, split_index-1] + elif abs(tokens - total_tokens // 2 - + value) < total_tokens * imbalance_ratio: + return [tokens - value, split_index - 1] # fail to split if it is imbalanced # TODO: split tokens in seq - else : + else: return [0, 0] else: - tokens = num_tokens // 2 + tokens = num_tokens // 2 return [tokens, tokens] return [0, 0] - + + def split_attn_tensor_type( - input_tensor: torch.Tensor, - index: int, - )-> List[torch.Tensor]: + input_tensor: torch.Tensor, + index: int, +) -> List[torch.Tensor]: return [input_tensor[:index], input_tensor[index:]] + + def split_attn_int_type( - var: int, - index: int, - )-> List[torch.Tensor]: - return [min(var,index), max(var-index, 0)] - + var: int, + index: int, +) -> List[torch.Tensor]: + return [min(var, index), max(var - index, 0)] + + def model_input_split_v1_mla_attn( - attn_metadata, - _metadata_cls, - ms_split_config: MSAttentionMetadataSplitConfig, - ) -> List[Any]: + attn_metadata, + _metadata_cls, + ms_split_config: MSAttentionMetadataSplitConfig, +) -> List[Any]: assert 0 < ms_split_config.num_micro_batches < 3 - assert attn_metadata is not None - [token_index, seq_index] = compute_split_seq_index(attn_metadata.query_lens, - attn_metadata.attn_state, attn_metadata.num_decode_tokens) - if token_index == 0 or seq_index == 0 or seq_index == len(attn_metadata.query_lens): + if attn_metadata is None: return [attn_metadata] - - query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1,), dtype=int) + [token_index, + seq_index] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, + attn_metadata.num_decode_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + attn_metadata.query_lens): + return [attn_metadata] + + query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ), + dtype=int) np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:]) if attn_metadata.num_prefills > 0: - prefill_query_start_loc = np.zeros(shape=(len(attn_metadata.prefill.query_lens) + 1,), dtype=int) - np.cumsum(attn_metadata.prefill.query_lens, out=prefill_query_start_loc[1:]) - + prefill_query_start_loc = np.zeros( + shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int) + np.cumsum(attn_metadata.prefill.query_lens, + out=prefill_query_start_loc[1:]) + # split attn metadata - [slot_mapping_pre, slot_mapping_post] = 
split_attn_tensor_type(attn_metadata.slot_mapping,token_index) - [num_decodes_pre, num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes, seq_index) - [num_decode_tokens_pre, num_decode_tokens_post] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index) - [num_prefills_pre, num_prefills_post] = split_attn_int_type(attn_metadata.num_prefills, max(0,seq_index-attn_metadata.num_decodes)) - seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills>0 else attn_metadata.decode.seq_lens - [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens,seq_index) - + [slot_mapping_pre, + slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping, + token_index) + [num_decodes_pre, + num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes, + seq_index) + [num_decode_tokens_pre, num_decode_tokens_post + ] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index) + [num_prefills_pre, num_prefills_post + ] = split_attn_int_type(attn_metadata.num_prefills, + max(0, seq_index - attn_metadata.num_decodes)) + seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens + [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: # the attn_mla kernel in torch npu only accept 128*128 attn mask attn_mask_pre = attn_mask_post = attn_metadata.attn_mask @@ -81,98 +103,112 @@ def model_input_split_v1_mla_attn( # should be none in decode only state attn_mask_pre = attn_mask_post = attn_metadata.attn_mask attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly - else : + else: # chunked prefill if num_prefills_pre > 0: attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill - attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(seq_lens_pre)].contiguous() + attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( + seq_lens_pre)].contiguous() attn_state_post = AscendAttentionState.ChunkedPrefill - attn_mask_post = attn_metadata.attn_mask[token_index:, :max(seq_lens_post)].contiguous() + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() else: attn_state_pre = AscendAttentionState.DecodeOnly attn_mask_pre = None attn_state_post = AscendAttentionState.ChunkedPrefill - attn_mask_post = attn_metadata.attn_mask[token_index:, :max(seq_lens_post)].contiguous() + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, AscendMLAPrefillMetadata) if num_prefills_pre > 0: # split metadata.prefill - [input_positions_pre, input_positions_post] = split_attn_tensor_type(attn_metadata.prefill.input_positions, token_index - attn_metadata.num_decode_tokens) - [block_tables_pre, block_tables_post] = split_attn_tensor_type(attn_metadata.prefill.block_table, seq_index- attn_metadata.num_decodes) - [prefill_query_lens_pre, prefill_query_lens_post] = split_attn_tensor_type(attn_metadata.prefill.query_lens,seq_index- attn_metadata.num_decodes) + [input_positions_pre, input_positions_post] = split_attn_tensor_type( + attn_metadata.prefill.input_positions, + token_index - attn_metadata.num_decode_tokens) + [block_tables_pre, block_tables_post + ] = split_attn_tensor_type(attn_metadata.prefill.block_table, + seq_index - attn_metadata.num_decodes) + [prefill_query_lens_pre, 
prefill_query_lens_post + ] = split_attn_tensor_type(attn_metadata.prefill.query_lens, + seq_index - attn_metadata.num_decodes) context_len_pre = seq_lens_pre[attn_metadata.num_decodes:] context_len_post = seq_lens_post prefill_max_query_len_pre = max(prefill_query_lens_pre) prefill_max_query_len_post = max(prefill_query_lens_post) prefill_pre = AscendMLAPrefillMetadata( - attn_mask=attn_mask_pre, - query_lens=prefill_query_lens_pre, - seq_lens=seq_lens_pre, - input_positions=input_positions_pre, - context_lens=context_len_pre, - block_table=block_tables_pre, - max_query_len=prefill_max_query_len_pre, - max_seq_lens=context_len_pre.max().item(), - ) + attn_mask=attn_mask_pre, + query_lens=prefill_query_lens_pre, + seq_lens=seq_lens_pre, + input_positions=input_positions_pre, + context_lens=context_len_pre, + block_table=block_tables_pre, + max_query_len=prefill_max_query_len_pre, + max_seq_lens=context_len_pre.max().item(), + ) prefill_post = AscendMLAPrefillMetadata( - attn_mask=attn_mask_post, - query_lens=prefill_query_lens_post, - seq_lens=seq_lens_post, - input_positions=input_positions_post, - context_lens=context_len_post, - block_table=block_tables_post, - max_query_len=prefill_max_query_len_post, - max_seq_lens=context_len_post.max().item(), - ) + attn_mask=attn_mask_post, + query_lens=prefill_query_lens_post, + seq_lens=seq_lens_post, + input_positions=input_positions_post, + context_lens=context_len_post, + block_table=block_tables_post, + max_query_len=prefill_max_query_len_post, + max_seq_lens=context_len_post.max().item(), + ) decode_pre = attn_metadata.decode decode_post = None - else : + else: # prefill is None, split metadata.decode - [input_positions_pre, input_positions_post] = split_attn_tensor_type(attn_metadata.decode.input_positions, token_index ) - [block_tables_pre, block_tables_post] = split_attn_tensor_type(attn_metadata.decode.block_table, seq_index) - [decode_seq_lens_pre, decode_seq_lens_post] = split_attn_tensor_type(seq_lens,seq_index) + [input_positions_pre, input_positions_post + ] = split_attn_tensor_type(attn_metadata.decode.input_positions, + token_index) + [block_tables_pre, block_tables_post + ] = split_attn_tensor_type(attn_metadata.decode.block_table, + seq_index) + [decode_seq_lens_pre, + decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) decode_pre = AscendMLADecodeMetadata( - input_positions=input_positions_pre, - block_table=block_tables_pre, - seq_lens=decode_seq_lens_pre, - max_seq_lens=max(decode_seq_lens_pre), - seq_lens_list=decode_seq_lens_pre.tolist(), + input_positions=input_positions_pre, + block_table=block_tables_pre, + seq_lens=decode_seq_lens_pre, + max_seq_lens=max(decode_seq_lens_pre), + seq_lens_list=decode_seq_lens_pre.tolist(), ) decode_post = AscendMLADecodeMetadata( - input_positions=input_positions_post, - block_table=block_tables_post, - seq_lens=decode_seq_lens_post, - max_seq_lens=max(decode_seq_lens_post), - seq_lens_list=decode_seq_lens_post.tolist(), + input_positions=input_positions_post, + block_table=block_tables_post, + seq_lens=decode_seq_lens_post, + max_seq_lens=max(decode_seq_lens_post), + seq_lens_list=decode_seq_lens_post.tolist(), ) prefill_pre = None prefill_post = attn_metadata.prefill # construct metadata from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata attention_metadata_pre = _metadata_cls( - num_actual_tokens=token_index, - num_input_tokens=token_index, - head_dim=attn_metadata.head_dim, - slot_mapping=slot_mapping_pre, - num_decodes=num_decodes_pre, - 
num_prefills=num_prefills_pre, - num_decode_tokens=num_decode_tokens_pre, - attn_state = attn_state_pre, - attn_mask = attn_mask_pre, - prefill=prefill_pre, - decode=decode_pre, + num_actual_tokens=token_index, + num_input_tokens=token_index, + head_dim=attn_metadata.head_dim, + slot_mapping=slot_mapping_pre, + num_decodes=num_decodes_pre, + num_prefills=num_prefills_pre, + num_decode_tokens=num_decode_tokens_pre, + attn_state=attn_state_pre, + attn_mask=attn_mask_pre, + prefill=prefill_pre, + decode=decode_pre, ) attention_metadata_post = _metadata_cls( - num_actual_tokens=attn_metadata.num_actual_tokens - token_index, - num_input_tokens = attn_metadata.num_input_tokens - token_index, - head_dim = attn_metadata.head_dim, - slot_mapping=slot_mapping_post, - num_decodes=num_decodes_post, - num_prefills=num_prefills_post, - num_decode_tokens=num_decode_tokens_post, - attn_mask = attn_mask_post, - attn_state = attn_state_post, - prefill=prefill_post, - decode=decode_post, + num_actual_tokens=attn_metadata.num_actual_tokens - token_index, + num_input_tokens=attn_metadata.num_input_tokens - token_index, + head_dim=attn_metadata.head_dim, + slot_mapping=slot_mapping_post, + num_decodes=num_decodes_post, + num_prefills=num_prefills_post, + num_decode_tokens=num_decode_tokens_post, + attn_mask=attn_mask_post, + attn_state=attn_state_post, + prefill=prefill_post, + decode=decode_post, ) return [attention_metadata_pre, attention_metadata_post] From e5eed0f585b0e050de15026b5bd25ca2a311a6eb Mon Sep 17 00:00:00 2001 From: zhuohuan Date: Thu, 29 May 2025 16:16:24 +0800 Subject: [PATCH 05/11] [fix]: fix accuracy issues for dbo in deepseek Signed-off-by: zhuohuan --- vllm_ascend/models/deepseek_v2.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 70add15059..85a74d234a 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -384,10 +384,8 @@ def _forward_ms_op_gate( def _forward_ms_op_tp_allgather( self, hidden_states: torch.Tensor, - shared_output: torch.Tensor, chunk_hidden_states: torch.Tensor, num_tokens: int = 0, - hidden_dim: int = 0, ): if self.tp_size > 1: @@ -414,11 +412,6 @@ def _forward_ms_op_tp_allgather( else: final_hidden_states = hidden_states - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output - final_hidden_states = final_hidden_states.view( - num_tokens, hidden_dim) - return final_hidden_states @@ -793,7 +786,7 @@ def _forward_ms_layer( num_token, hidden_dim = hidden_states[i].shape hidden_states[i] = hidden_states[i].view(-1, hidden_dim) - #num_tokens.append(num_token) + num_tokens.append(num_token) hidden_dims.append(hidden_dim) if self.mlp.n_shared_experts is not None: # TODO: we can move shared expert computation into next block if reduce results is false @@ -829,7 +822,6 @@ def _forward_ms_layer( if padded_num_tokens > 0: hidden_states[i] = nn.functional.pad( hidden_states[i], (0, 0, 0, padded_num_tokens)) - num_tokens.append(padded_num_tokens) chunk_hidden_state = torch.tensor_split(hidden_states[i], self.mlp.tp_size, dim=0) @@ -887,10 +879,14 @@ def _forward_ms_layer( ) with set_multistream_context(context, i): hidden_states[i] = self.mlp._forward_ms_op_tp_allgather( - hidden_states[i], shared_outputs[i], - chunk_hidden_states[i], num_tokens[i], hidden_dims[i]) + hidden_states[i], chunk_hidden_states[i], + padded_num_tokens) with torch.npu.stream(ms_metadata.communicate_stream): # last + if 
shared_output is not None: + hidden_states[i] = hidden_states[i] + shared_outputs[i] + hidden_states[i] = hidden_states[i].view( + num_tokens[i], hidden_dims[i]) if isinstance(self.mlp, CustomDeepseekV2MLP ) and hidden_states[i].dtype == torch.float16: # Fix FP16 overflow From e43a6185ad8f5548a5392f0ba1624f37ea58e258 Mon Sep 17 00:00:00 2001 From: zhuohuan Date: Fri, 30 May 2025 17:11:00 +0800 Subject: [PATCH 06/11] [fix]: add e2e test for dbo Signed-off-by: zhuohuan --- .../test_offline_inference_with_dbo.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/multicard/test_offline_inference_with_dbo.py diff --git a/tests/multicard/test_offline_inference_with_dbo.py b/tests/multicard/test_offline_inference_with_dbo.py new file mode 100644 index 0000000000..00dfea093f --- /dev/null +++ b/tests/multicard/test_offline_inference_with_dbo.py @@ -0,0 +1,48 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# +"""Compare the short outputs of HF and vLLM when using greedy sampling. + +Run `pytest tests/test_offline_inference.py`. 
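A minimal sketch of the prefill split heuristic added in vllm_ascend/multistream/ms_split.py above; balanced_split is an illustrative stand-in written for this review, not a function from the patch:

    def balanced_split(query_lens, imbalance_ratio=0.1):
        # Walk the requests until the running token count reaches half of the
        # total; accept the split only if it lands within imbalance_ratio of a
        # perfect half, mirroring compute_split_seq_index for the prefill path.
        total = sum(query_lens)
        tokens = 0
        for i, n in enumerate(query_lens, start=1):
            tokens += n
            if tokens >= total // 2:
                if abs(tokens - total // 2) < total * imbalance_ratio:
                    return [tokens, i]            # [token_index, seq_index]
                if abs(tokens - n - total // 2) < total * imbalance_ratio:
                    return [tokens - n, i - 1]    # previous boundary is balanced
                return [0, 0]                     # too imbalanced, skip DBO
        return [0, 0]

    print(balanced_split([3, 3, 2, 4]))   # [6, 2]: 6 tokens / 2 requests per micro-batch
    print(balanced_split([10, 1, 1]))     # [0, 0]: no balanced split exists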
+""" +import os + +import vllm # noqa: F401 + +from tests.conftest import VllmRunner + +os.environ["VLLM_USE_V1"] = "1" +os.environ["VLLM_ENABLE_DBO"] = "1" + + +def test_deepseek_model_with_dbo(): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 100 + dtype = "half" + max_tokens = 5 + with VllmRunner( + "deepseek-ai/DeepSeek-V2-Lite", + dtype=dtype, + tensor_parallel_size=4, + distributed_executor_backend="mp", + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) From 142402e765f4fb2e35a4cf1c8d659dbb20344322 Mon Sep 17 00:00:00 2001 From: zhuohuan Date: Fri, 30 May 2025 21:49:12 +0800 Subject: [PATCH 07/11] [feat]: support v0.9.0 modification of mla attn metadata Signed-off-by: zhuohuan --- .../test_offline_inference_distributed.py | 13 ++++- .../test_offline_inference_with_dbo.py | 48 ------------------- vllm_ascend/attention/mla_v1.py | 2 - vllm_ascend/models/deepseek_v2.py | 10 ++-- vllm_ascend/multistream/ms_split.py | 29 +++++++++++ 5 files changed, 46 insertions(+), 56 deletions(-) delete mode 100644 tests/multicard/test_offline_inference_with_dbo.py diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index dd8da8c37e..3bfcf33be0 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -74,10 +74,21 @@ def test_models_distributed_topk() -> None: top_k=50, top_p=0.9) + +def test_deepseek_model_with_dbo(): + os.environ["VLLM_ENABLE_DBO"] = "1" + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 10 + dtype = "half" + max_tokens = 5 with VllmRunner( "deepseek-ai/DeepSeek-V2-Lite", dtype=dtype, tensor_parallel_size=4, distributed_executor_backend="mp", ) as vllm_model: - vllm_model.generate(example_prompts, sampling_params) + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/multicard/test_offline_inference_with_dbo.py b/tests/multicard/test_offline_inference_with_dbo.py deleted file mode 100644 index 00dfea093f..0000000000 --- a/tests/multicard/test_offline_inference_with_dbo.py +++ /dev/null @@ -1,48 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py -# -"""Compare the short outputs of HF and vLLM when using greedy sampling. - -Run `pytest tests/test_offline_inference.py`. 
-""" -import os - -import vllm # noqa: F401 - -from tests.conftest import VllmRunner - -os.environ["VLLM_USE_V1"] = "1" -os.environ["VLLM_ENABLE_DBO"] = "1" - - -def test_deepseek_model_with_dbo(): - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] * 100 - dtype = "half" - max_tokens = 5 - with VllmRunner( - "deepseek-ai/DeepSeek-V2-Lite", - dtype=dtype, - tensor_parallel_size=4, - distributed_executor_backend="mp", - ) as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 40c7101f09..9713fefc4d 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -16,7 +16,6 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla -from vllm_ascend.worker.model_runner_v1 import NPUModelRunner if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -402,7 +401,6 @@ def build( return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, query_lens=query_lens.tolist(), - seq_lens=seq_lens, slot_mapping=slot_mapping, head_dim=self.runner.model_config.get_head_size(), num_decodes=self._num_decodes, diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 85a74d234a..47de966f2a 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -78,6 +78,7 @@ from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod +from vllm_ascend.utils import dispose_tensor VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 VLLM_ENABLE_DBO: bool = envs_ascend.VLLM_ENABLE_DBO @@ -1063,9 +1064,11 @@ def forward( return hidden_states def can_run_ms(self): - # currently we only enable prefill overlap attn_metadata = get_forward_context().attn_metadata - # profile run + # support mla attention and V1 engine at present + if not self.use_mla or not envs.VLLM_USE_V1: + return False + # enable prefill overlap if attn_metadata is None or attn_metadata.num_prefills == 0: return False else: @@ -1079,9 +1082,6 @@ def can_run_ms(self): if self.multistream_config is None: return False - # support mla attention and V1 engine at present - if not self.use_mla or not envs.VLLM_USE_V1: - return False # check whether the total tokens exceed the threshold if attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: return False diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 487fa151c9..3af6337e47 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Any, List, Optional import numpy as np @@ -95,6 +96,14 @@ def model_input_split_v1_mla_attn( seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) + query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] + query_start_loc_post = deepcopy( + attn_metadata.query_start_loc[seq_index:] + ) - attn_metadata.query_start_loc[seq_index] + [block_table_pre, + block_table_post] = 
split_attn_tensor_type(attn_metadata.block_tables, + seq_index) + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: # the attn_mla kernel in torch npu only accept 128*128 attn mask attn_mask_pre = attn_mask_post = attn_metadata.attn_mask @@ -131,6 +140,18 @@ def model_input_split_v1_mla_attn( [prefill_query_lens_pre, prefill_query_lens_post ] = split_attn_tensor_type(attn_metadata.prefill.query_lens, seq_index - attn_metadata.num_decodes) + prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[: + seq_index + + + 1 - + attn_metadata + . + num_decodes] + prefill_query_start_loc_post = deepcopy( + attn_metadata.prefill.query_start_loc[seq_index - + attn_metadata.num_decodes:] + ) - attn_metadata.prefill.query_start_loc[seq_index - + attn_metadata.num_decodes] context_len_pre = seq_lens_pre[attn_metadata.num_decodes:] context_len_post = seq_lens_post prefill_max_query_len_pre = max(prefill_query_lens_pre) @@ -139,6 +160,7 @@ def model_input_split_v1_mla_attn( attn_mask=attn_mask_pre, query_lens=prefill_query_lens_pre, seq_lens=seq_lens_pre, + query_start_loc=prefill_query_start_loc_pre, input_positions=input_positions_pre, context_lens=context_len_pre, block_table=block_tables_pre, @@ -149,6 +171,7 @@ def model_input_split_v1_mla_attn( attn_mask=attn_mask_post, query_lens=prefill_query_lens_post, seq_lens=seq_lens_post, + query_start_loc=prefill_query_start_loc_post, input_positions=input_positions_post, context_lens=context_len_post, block_table=block_tables_post, @@ -190,6 +213,9 @@ def model_input_split_v1_mla_attn( num_input_tokens=token_index, head_dim=attn_metadata.head_dim, slot_mapping=slot_mapping_pre, + seq_lens=seq_lens_pre, + query_start_loc=query_start_loc_pre, + block_tables=block_table_pre, num_decodes=num_decodes_pre, num_prefills=num_prefills_pre, num_decode_tokens=num_decode_tokens_pre, @@ -203,6 +229,9 @@ def model_input_split_v1_mla_attn( num_input_tokens=attn_metadata.num_input_tokens - token_index, head_dim=attn_metadata.head_dim, slot_mapping=slot_mapping_post, + seq_lens=seq_lens_post, + query_start_loc=query_start_loc_post, + block_tables=block_table_post, num_decodes=num_decodes_post, num_prefills=num_prefills_post, num_decode_tokens=num_decode_tokens_post, From 2e046e85e349bd8cd189752fcd29d839afa6b9b0 Mon Sep 17 00:00:00 2001 From: zhuohuan Date: Tue, 3 Jun 2025 17:47:28 +0800 Subject: [PATCH 08/11] [fix]: optimize the dbo execution and fix minor issues Signed-off-by: zhuohuan --- vllm_ascend/models/deepseek_v2.py | 92 +++++++++++----------------- vllm_ascend/multistream/base.py | 1 - vllm_ascend/multistream/context.py | 2 - vllm_ascend/multistream/decorator.py | 2 - vllm_ascend/multistream/layers.py | 1 - vllm_ascend/multistream/metadata.py | 17 ++--- 6 files changed, 45 insertions(+), 70 deletions(-) diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 47de966f2a..c69ccf5a14 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -732,13 +732,12 @@ def _forward_ms_layer( shared_outputs = [] router_logits = [] chunk_hidden_states = [] - ''' block 1 : attention - block 2 : attn tp communication, currently we switch to the comm stream - in tensor_model_parallel_all_reduce; - the attn computation of microbatch 1 can be overlapped with the moe - communication in the previous layer, and the attn computation of microbatch - 2 can be overlapped with the attn communication of microbatch 1 - ''' + + # block 
1 : attention + # block 2 : attn tp communication + # the attn computation of microbatch 1 can be overlapped with the moe + # communication in the previous layer, and the attn computation of microbatch 2 + # can be overlapped with the attn communication of microbatch 1 for i in range(num_micro_batchs): # wait last layer moe finishing communication ms_metadata.try_wait_event(layer_index - 1, i, @@ -765,10 +764,10 @@ def _forward_ms_layer( hidden_states[i], residual[i] = self._forward_ms_op_attn( positions[i], hidden_states[i], residual[i], kv_cache, attn_metadata[i]) - ''' block 3 : shared experts - if there is an allreduce ops in shared expert, we can overlap it with the computation of the - shared expert for next microbatch or moe gating - ''' + + # block 3 : shared experts + # if there is an allreduce ops in shared expert, we can overlap it with the computation of the + # shared expert for next microbatch or moe gating for i in range(num_micro_batchs): ms_metadata.try_wait_event(layer_index, i, MSEventKey.ATTN_AR_FINISH) @@ -797,7 +796,6 @@ def _forward_ms_layer( # block 4 : moe for i in range(num_micro_batchs): - #ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH) # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. # TODO: need a better flag to indicate whether in profile run or not. @@ -810,13 +808,6 @@ def _forward_ms_layer( enable_force_load_balance = False if self.mlp.tp_size > 1: - #if num_tokens[i] < self.mlp.tp_size: - # target_size = self.mlp.tp_size - # new_hidden_states = torch.empty([target_size, hidden_dims[i]], - # dtype=hidden_states[i].dtype, - # device=hidden_states[i].device) - # new_hidden_states[:num_tokens[i]] = hidden_states[i] - # hidden_states[i] = new_hidden_states num_token, _ = hidden_states[i].shape padded_num_tokens = (self.mlp.tp_size - num_token % self.mlp.tp_size) % self.mlp.tp_size @@ -839,18 +830,12 @@ def _forward_ms_layer( else: real_top_k = self.mlp.experts.top_k - if VLLM_ENABLE_MC2 and not is_prefill: - ... - hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp( local_hidden_states, router_logits[i], is_prefill, real_top_k, enable_force_load_balance) - if VLLM_ENABLE_MC2 and not is_prefill: - ... 
- ''' the following kernels will be submitted to the comm stream to overlap the computation of the - moe computation of next microbatch and the attn computation of next layer - ''' + # the following kernels will be submitted to the comm stream to overlap the computation of the + # moe computation of next microbatch and the attn computation of next layer context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, before_comm_event=ms_metadata.ms_events[layer_index][i][ @@ -860,7 +845,6 @@ def _forward_ms_layer( ) context.before_comm_event.record() with torch.npu.stream(ms_metadata.communicate_stream): - #with set_multistream_context(context, i): context.before_comm_event.wait() if self.mlp.experts.reduce_results and ( self.mlp.experts.tp_size > 1 @@ -868,7 +852,7 @@ def _forward_ms_layer( hidden_states[i] = tensor_model_parallel_all_reduce( hidden_states[i]) context.after_comm_event.record() - # check here + hidden_states[ i] = hidden_states[i] * self.mlp.routed_scaling_factor context = MultiStreamStepMetadata( @@ -993,21 +977,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ["hidden_states", "residual"], config.hidden_size)) # tbo related members - self.multistream_config: Optional[MultiStreamConfig] = None if VLLM_ENABLE_DBO: + self.use_mla = model_config.use_mla self.multistream_config = MultiStreamConfig() - - self.use_mla = model_config.use_mla - self.multistream_metadata = make_multistream_metadata_ds( - start_layer=self.start_layer + self.first_k_dense_replace, - end_layer=self.end_layer, - causal_lm=getattr(config, "causal_lm", True), - multistream_config=self.multistream_config, - ) - self.ms_pre_layer = MultiStreamPreTransformerLayer( - self.multistream_metadata) - self.ms_post_layer = MultiStreamPostTransformerLayer( - self.multistream_metadata) + multistream_metadata = make_multistream_metadata_ds( + start_layer=self.start_layer + self.first_k_dense_replace, + end_layer=self.end_layer, + causal_lm=getattr(config, "causal_lm", True), + multistream_config=self.multistream_config, + ) + self.ms_pre_layer = MultiStreamPreTransformerLayer( + multistream_metadata) + self.ms_post_layer = MultiStreamPostTransformerLayer( + multistream_metadata) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -1032,11 +1014,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - num_normal_layers = (self.first_k_dense_replace - if self.multistream_config is not None + num_normal_layers = (self.first_k_dense_replace if VLLM_ENABLE_DBO and self.can_run_ms() else self.end_layer - self.start_layer) - # if we enable multistream/dbo, only process dense layers here + for i in range(self.start_layer, self.start_layer + num_normal_layers): layer = self.layers[i] hidden_states, residual = layer( @@ -1046,13 +1027,15 @@ def forward( attn_metadata) moe_start_layer = self.start_layer + num_normal_layers - hidden_states, residual = self._forward_ms_layers( - positions=positions, - hidden_states=hidden_states, - residual=residual, - moe_start_layer=moe_start_layer, - kv_caches=kv_caches, - ) + if moe_start_layer != self.end_layer: + # if we enable multistream/dbo, process sparse layers here + hidden_states, residual = self._forward_ms_layers( + positions=positions, + hidden_states=hidden_states, + residual=residual, + moe_start_layer=moe_start_layer, + kv_caches=kv_caches, + ) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -1079,11 
+1062,8 @@ def can_run_ms(self): if token_index == 0 or seq_index == 0 or seq_index == len( attn_metadata.query_lens): return False - - if self.multistream_config is None: - return False # check whether the total tokens exceed the threshold - if attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: + if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: return False return True diff --git a/vllm_ascend/multistream/base.py b/vllm_ascend/multistream/base.py index 9e3dd4b443..fba58b460e 100644 --- a/vllm_ascend/multistream/base.py +++ b/vllm_ascend/multistream/base.py @@ -2,7 +2,6 @@ from enum import Enum -# TODO: move this part to vllm class MSEventKey(Enum): ATTN_COM_FINISH = 0 ATTN_AR_FINISH = 1 diff --git a/vllm_ascend/multistream/context.py b/vllm_ascend/multistream/context.py index 9cda0649a8..afc8ba0235 100644 --- a/vllm_ascend/multistream/context.py +++ b/vllm_ascend/multistream/context.py @@ -1,8 +1,6 @@ from contextlib import contextmanager from typing import Any -# TODO: move this part to vllm - _ms_comm_context: Any = None _cur_micro_batch_num: int = -1 _ms_layer_index_context: int = -1 diff --git a/vllm_ascend/multistream/decorator.py b/vllm_ascend/multistream/decorator.py index 381634a450..6c7f16aeb8 100644 --- a/vllm_ascend/multistream/decorator.py +++ b/vllm_ascend/multistream/decorator.py @@ -3,8 +3,6 @@ from .context import (get_multistream_layer_context, get_multistream_microbatch_context) -# TODO: move this part to vllm - logger = init_logger(__name__) diff --git a/vllm_ascend/multistream/layers.py b/vllm_ascend/multistream/layers.py index b940d09bd9..c5273bce73 100644 --- a/vllm_ascend/multistream/layers.py +++ b/vllm_ascend/multistream/layers.py @@ -10,7 +10,6 @@ from .metadata import MultiStreamMetadata -# TODO: move this part to vllm class MultiStreamPreTransformerLayer(torch.nn.Module): def __init__(self, multistream_metadata: MultiStreamMetadata): diff --git a/vllm_ascend/multistream/metadata.py b/vllm_ascend/multistream/metadata.py index 9a721cfe16..b521d3f85f 100644 --- a/vllm_ascend/multistream/metadata.py +++ b/vllm_ascend/multistream/metadata.py @@ -2,9 +2,10 @@ from typing import Dict, List, Optional, Tuple, Union import torch -from vllm.attention.backends.abstract import AttentionMetadata from vllm.sequence import IntermediateTensors +from vllm_ascend.attention.mla_v1 import AscendMLAMetadata + from .base import MSAttentionMetadataSplitConfig, MSEventKey @@ -111,19 +112,19 @@ def try_record_event(self, layer_index: int, micro_batch_index: int, def split_micro_batch( self, - attn_metadata: "AttentionMetadata", + attn_metadata: "AscendMLAMetadata", intput_tensors: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors_keys: Optional[List[str]] = None, - ) -> Tuple[bool, Union[AttentionMetadata, List[AttentionMetadata]], Union[ + ) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[ List[torch.Tensor], List[List[torch.Tensor]]], Union[ IntermediateTensors, List[IntermediateTensors]]]: - attn_metadata = attn_metadata.split_metadata_for_multistream( + attn_metadata_list = attn_metadata.split_metadata_for_multistream( self.ms_split_config) - if len(attn_metadata) == 1: - return False, attn_metadata[ + if len(attn_metadata_list) == 1: + return False, attn_metadata_list[ 0], intput_tensors, intermediate_tensors - split_index = attn_metadata[0].slot_mapping.shape[0] + split_index = 
attn_metadata_list[0].slot_mapping.shape[0] input_tensors = split_micro_batches_tensors(intput_tensors, split_index) if intermediate_tensors is not None: @@ -134,7 +135,7 @@ def split_micro_batch( IntermediateTensors(inter_tensors) for inter_tensors in inter_tensors_list ] - return True, attn_metadata, input_tensors, intermediate_tensors + return True, attn_metadata_list, input_tensors, intermediate_tensors def merge_micro_batches( self, input_tensors: Union[List[torch.Tensor], From 85bc104255772e3953e50dc20c4ab89db60eb6cb Mon Sep 17 00:00:00 2001 From: zhuohuan Date: Thu, 5 Jun 2025 23:40:04 +0800 Subject: [PATCH 09/11] [fix]: fix comment issues by separating dbo model Signed-off-by: zhuohuan --- examples/offline_dualbatch_overlap_npu.py | 55 + .../test_offline_inference_distributed.py | 20 +- vllm_ascend/attention/mla_v1.py | 15 +- vllm_ascend/envs.py | 2 + vllm_ascend/models/__init__.py | 5 + vllm_ascend/models/deepseek_dbo.py | 1118 +++++++++++++++++ vllm_ascend/models/deepseek_v2.py | 415 +----- vllm_ascend/multistream/context.py | 2 +- vllm_ascend/multistream/ms_split.py | 2 + 9 files changed, 1203 insertions(+), 431 deletions(-) create mode 100644 examples/offline_dualbatch_overlap_npu.py create mode 100644 vllm_ascend/models/deepseek_dbo.py diff --git a/examples/offline_dualbatch_overlap_npu.py b/examples/offline_dualbatch_overlap_npu.py new file mode 100644 index 0000000000..6c4974647c --- /dev/null +++ b/examples/offline_dualbatch_overlap_npu.py @@ -0,0 +1,55 @@ +import os +import time + +from vllm import LLM, SamplingParams + +# enable dual-batch overlap for vllm ascend +os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1" +os.environ["VLLM_USE_V1"] = "1" + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] * 10 +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + +def main(): + # Create an LLM. + llm = LLM( + model="deepseek-ai/DeepSeek-V2-Lite", + hf_overrides={ + "architectures": ["DeepseekDBOForCausalLM"], + }, # override the model arch to run the dbo model + enforce_eager=True, + tensor_parallel_size=8, + max_num_seqs=16, + max_model_len=8192, + max_num_batched_tokens=32768, + block_size=128, + compilation_config=1, + gpu_memory_utilization=0.96) + + # Generate texts from the prompts. The output is a list of RequestOutput + # objects that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + print("-" * 50) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + + # Add a buffer to wait for profiler in the background process + # (in case MP is on) to finish writing profiling output. 
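The query_start_loc re-basing added to ms_split.py in the patch above is easiest to see with concrete numbers; a small sketch, with clone() standing in for the deepcopy used in the patch and the lengths chosen arbitrarily:

    import torch

    query_start_loc = torch.tensor([0, 2, 5, 9])   # cumulative starts for query lens [2, 3, 4]
    seq_index = 2                                  # first two requests form micro-batch 0
    pre = query_start_loc[:seq_index + 1]          # tensor([0, 2, 5])
    post = query_start_loc[seq_index:].clone() - query_start_loc[seq_index]
    print(pre.tolist(), post.tolist())             # [0, 2, 5] [0, 4]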
+ time.sleep(10) + + +if __name__ == "__main__": + main() diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index 3bfcf33be0..b41860c283 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -74,14 +74,19 @@ def test_models_distributed_topk() -> None: top_k=50, top_p=0.9) + with VllmRunner( + "deepseek-ai/DeepSeek-V2-Lite", + dtype=dtype, + tensor_parallel_size=4, + distributed_executor_backend="mp", + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + -def test_deepseek_model_with_dbo(): - os.environ["VLLM_ENABLE_DBO"] = "1" +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"}) +def test_models_distributed_DeepSeek_dbo(): example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", + "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", ] * 10 dtype = "half" max_tokens = 5 @@ -90,5 +95,8 @@ def test_deepseek_model_with_dbo(): dtype=dtype, tensor_parallel_size=4, distributed_executor_backend="mp", + hf_overrides={ + "architectures": ["DeepseekDBOForCausalLM"], + } # override the model arch to the dbo version ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 9713fefc4d..91ddf43888 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -14,6 +14,7 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla @@ -119,7 +120,7 @@ class AscendMLAMetadata: with_prefill_across_dp: bool = False - query_lens: list[int] = None + query_lens: Optional[list[int]] = None # The dimension of the attention heads head_dim: Optional[int] = None attn_mask: torch.Tensor = None @@ -412,6 +413,7 @@ def build( decode=decode_metadata, query_start_loc=query_start_loc, block_tables=block_table, + seq_lens=seq_lens, with_prefill_across_dp=with_prefill_across_dp, ) @@ -600,9 +602,6 @@ def _forward_prefill( attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) - # A better way is to modify the communication ops or RowParallel Layer in vllm; - from vllm_ascend.multistream.context import \ - get_multistream_comm_context current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is None: return self.o_proj(attn_output)[0] @@ -710,8 +709,6 @@ def _forward_decode( context_lens=attn_metadata.decode.seq_lens, # type:ignore mla_vheadsize=self.kv_lora_rank, out=attn_output) - from vllm_ascend.multistream.context import \ - get_multistream_comm_context current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is None: return self._v_up_proj_and_o_proj(attn_output) @@ -848,8 +845,6 @@ def forward( # FIX: aicore move should be also placed on the comm stream in dbo, # otherwise it may affect the accuracy # TODO: use an elegant way to overlap - from vllm_ascend.multistream.context import \ - get_multistream_comm_context output_prefill = self._forward_prefill(prefill_q, 
prefill_k_c_normed, prefill_k_pe, kv_cache, @@ -868,14 +863,12 @@ def forward( decode_k_nope, decode_k_pe, kv_cache, attn_metadata) else: - from vllm_ascend.multistream.context import \ - get_multistream_comm_context output_decode = self._forward_decode(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, kv_cache, attn_metadata) - current_ms_metadata = get_multistream_comm_context() + current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is not None: with torch.npu.stream(current_ms_metadata.comm_stream): output[:num_decode_tokens] = output_decode diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 2fd7041fcd..96853b63f8 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -107,6 +107,8 @@ # Whether to enable the trace recompiles from pytorch. "VLLM_ASCEND_TRACE_RECOMPILES": lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))), + "VLLM_ASCEND_ENABLE_DBO": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))), # Whether to enable the model execute time observe profile. Disable it when # running vllm ascend in production environment. "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index e7f021fdb5..4e9a7289e9 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -2,6 +2,7 @@ def register_model(): + from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401 from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401 @@ -33,3 +34,7 @@ def register_model(): ModelRegistry.register_model( "Qwen3MoeForCausalLM", "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM") + + ModelRegistry.register_model( + "DeepseekDBOForCausalLM", + "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py new file mode 100644 index 0000000000..c3de6ae932 --- /dev/null +++ b/vllm_ascend/models/deepseek_dbo.py @@ -0,0 +1,1118 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# # Adapted from +# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py +# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py +# """Inference-only DeepseekV2/DeepseekV3 model.""" + +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +import torch_npu +import vllm.envs as envs +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_dp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.deepseek_v2 import \ + DeepseekV2ForCausalLM # ruff: noqa: E501 +from vllm.model_executor.models.deepseek_v2 import \ + yarn_get_mscale # ruff: noqa: E501 +from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention, + DeepseekV2DecoderLayer, + DeepseekV2MLAAttention) +from vllm.model_executor.models.utils import ( + PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +from vllm.sequence import IntermediateTensors + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.multistream.base import MSEventKey +from vllm_ascend.multistream.context import ( + advance_step_multistream_layer_context, get_multistream_comm_context, + get_multistream_layer_context, set_multistream_context) +from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, + MultiStreamPreTransformerLayer) +from vllm_ascend.multistream.metadata import (MultiStreamConfig, + MultiStreamStepMetadata, + make_multistream_metadata_ds) +from vllm_ascend.multistream.ms_split import compute_split_seq_index +from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod +from vllm_ascend.utils import dispose_tensor + +VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO +VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 + + +class CustomDeepseekDBOMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + 
quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant + self.is_dynamic_quant = not isinstance( + self.gate_up_proj.quant_method, + UnquantizedLinearMethod) and isinstance( + self.gate_up_proj.quant_method.quant_method, + AscendW8A8DynamicLinearMethod) + + def forward(self, x): + if self.is_dynamic_quant: + x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + x = torch_npu.npu_quant_matmul( + x, + self.gate_up_proj.weight, + self.gate_up_proj.weight_scale, + output_dtype=torch.int32, + ) + x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( + x=x, + weight_scale=self.gate_up_proj.weight_scale_fp32, + activation_scale=dynamic_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=None, + activate_left=True, + quant_mode=1) + x = torch_npu.npu_quant_matmul( + x, + self.down_proj.weight, + self.down_proj.weight_scale, + pertoken_scale=dynamic_scale, + output_dtype=torch.bfloat16, + ) + if self.down_proj.reduce_results and self.down_proj.tp_size > 1: + x = tensor_model_parallel_all_reduce(x) + return x + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + def _forward_ms_mlp(self, x): + current_ms_metadata = get_multistream_comm_context() + assert current_ms_metadata is not None + if self.is_dynamic_quant: + x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + x = torch_npu.npu_quant_matmul( + x, + self.gate_up_proj.weight, + self.gate_up_proj.weight_scale, + output_dtype=torch.int32, + ) + x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( + x=x, + weight_scale=self.gate_up_proj.weight_scale_fp32, + activation_scale=dynamic_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=None, + activate_left=True, + quant_mode=1) + x = torch_npu.npu_quant_matmul( + x, + self.down_proj.weight, + self.down_proj.weight_scale, + pertoken_scale=dynamic_scale, + output_dtype=torch.bfloat16, + ) + if self.down_proj.reduce_results and self.down_proj.tp_size > 1: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + x = tensor_model_parallel_all_reduce(x) + current_ms_metadata.after_comm_event.record() + return x + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + x, _ = self.down_proj(x) + current_ms_metadata.after_comm_event.record() + return x + + +class CustomDeepseekDBOMoE(nn.Module): + + top_k: int + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + self.routed_scaling_factor = config.routed_scaling_factor + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = None + + self.experts = AscendFusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = CustomDeepseekDBOMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.shared_experts", + ) + CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + + self.params_dtype = torch.get_default_dtype() + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. 
+ if attn_metadata is None: + # for profile run + is_prefill = True + enable_force_load_balance = True + else: + is_prefill = attn_metadata.num_prefills > 0 + enable_force_load_balance = False + if hasattr(attn_metadata, 'with_prefill_across_dp'): + is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + + num_tokens, hidden_size = hidden_states.shape + + old_hidden_states = hidden_states.clone() + + if self.tp_size > 1: + if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: + chunks = torch.chunk(hidden_states, self.tp_size, dim=0) + hidden_states = chunks[self.tp_rank] + elif not self.torchair_graph_enabled: + num_padding_tokens = (self.tp_size - + num_tokens % self.tp_size) % self.tp_size + # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C + if num_padding_tokens > 0: + hidden_states = nn.functional.pad( + hidden_states, (0, 0, 0, num_padding_tokens)) + chunk_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + hidden_states = chunk_hidden_states[self.tp_rank] + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=CustomDeepseekDBOMoE.top_k, + enable_force_load_balance=enable_force_load_balance, + ) * self.routed_scaling_factor + + if self.tp_size > 1: + if self.torchair_graph_enabled: + if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: + final_hidden_states = torch.zeros( + [num_tokens, hidden_size], + dtype=self.params_dtype, + device="npu") + dist.all_gather_into_tensor(final_hidden_states, + hidden_states, self.tp_group) + hidden_states = final_hidden_states + else: + hidden_states = tensor_model_parallel_all_reduce( + hidden_states) + else: + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_padding_tokens > 0: + hidden_states = hidden_states[:-num_padding_tokens] + + if self.n_shared_experts is not None: + shared_output = self.shared_experts(old_hidden_states) + + if shared_output is not None: + hidden_states = hidden_states + shared_output + + return hidden_states.view(num_tokens, hidden_size) + + # ----------------------------------------- TBO-related -------------------------------------------- + def _forward_ms_op_shared_expert( + self, + hidden_states: torch.Tensor, + ): + shared_output = self.shared_experts._forward_ms_mlp(hidden_states) + return shared_output + + def _forward_ms_op_gate( + self, + hidden_states: torch.Tensor, + ): + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + return router_logits + + def _forward_ms_op_tp_allgather( + self, + hidden_states: torch.Tensor, + chunk_hidden_states: torch.Tensor, + num_tokens: int = 0, + ): + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_tokens > 0: + final_hidden_states = final_hidden_states[:-num_tokens] + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_tokens > 0: + final_hidden_states = 
final_hidden_states[:-num_tokens] + current_ms_metadata.after_comm_event.record() + return final_hidden_states + + +class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + # In the MLA backend, kv_cache includes both k_c and + # pe (i.e. decoupled position embeddings). In particular, + # the concat_and_cache_mla op requires + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) + # i.e. 
+ # kv_lora_rank + qk_rope_head_dim == head_size + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if self.q_lora_rank is not None: + ckq = self.q_a_proj(hidden_states)[0] + hidden_states_or_q_c = self.q_a_layernorm(ckq) + else: + hidden_states_or_q_c = hidden_states + if self.torchair_graph_enabled: + forward_kwargs = {} + if envs.VLLM_USE_V1: + output_shape = hidden_states.shape + output = torch.empty(output_shape, + dtype=hidden_states_or_q_c.dtype, + device=hidden_states_or_q_c.device) + forward_kwargs['output'] = output + + output = self.mla_attn.impl.forward(self.mla_attn, + hidden_states_or_q_c, + hidden_states, None, kv_cache, + attn_metadata, + **forward_kwargs) + if envs.VLLM_USE_V1: + output = output.view(-1, output_shape[-1]) + return output + else: + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + return self.mla_attn(hidden_states_or_q_c, + kv_c_normed, + k_pe, + output_shape=hidden_states.shape) + + +class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. 
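A quick check of the index extraction on the line below, using a hypothetical prefix of the form make_layers produces:

    prefix = "model.layers.7"                   # hypothetical decoder-layer prefix
    assert int(prefix.split(sep='.')[-1]) == 7  # matches layer_idx below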
+ layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + # TODO: enable mla in vllm-ascend + if model_config.use_mla: + attn_cls = CustomDeepseekDBOMLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = CustomDeepseekDBOMoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = CustomDeepseekDBOMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + previous_hidden_states, previous_residual = hidden_states, residual + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + # Dispose hidden_states and residual from the previous layer + # to save npu memory because they're no longer used. + dispose_tensor(previous_hidden_states) + dispose_tensor(previous_residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if isinstance(self.mlp, CustomDeepseekDBOMoE): + hidden_states = self.mlp(hidden_states, attn_metadata) + else: + hidden_states = self.mlp(hidden_states) + + if isinstance( + self.mlp, + CustomDeepseekDBOMLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. 
/ self.routed_scaling_factor + + return hidden_states, residual + + # ----------------------------------------- TBO-related -------------------------------------------- + def _forward_ms_layer( + self, + positions: List[torch.Tensor], + hidden_states: List[torch.Tensor], + residual: List[torch.Tensor], + attn_metadata: List[AttentionMetadata], + kv_cache: Optional[torch.Tensor] = None, + is_prefill: bool = False, + ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: + layer_index, ms_metadata, _ = get_multistream_layer_context() + assert layer_index >= 0 and ms_metadata is not None + num_micro_batchs = ms_metadata.ms_config.num_micro_batches + assert isinstance(self.mlp, CustomDeepseekDBOMoE) + assert len(positions) == num_micro_batchs + assert len(hidden_states) == num_micro_batchs + assert residual is not None + assert attn_metadata is not None + num_tokens = [] + hidden_dims = [] + shared_outputs = [] + router_logits = [] + chunk_hidden_states = [] + + # block 1 : attention + # block 2 : attn tp communication + # the attn computation of microbatch 1 can be overlapped with the moe + # communication in the previous layer, and the attn computation of microbatch 2 + # can be overlapped with the attn communication of microbatch 1 + for i in range(num_micro_batchs): + # wait last layer moe finishing communication + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.FFN_AR_FINISH) + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.ATTN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.ATTN_AR_FINISH], + ) + + with set_multistream_context(context, i): + forward_context = get_forward_context() + forward_context.attn_metadata = attn_metadata[i] + + # input layernorm + hidden_states[i], residual[ + i] = self._forward_ms_op_input_layernorm( + hidden_states[i], residual[i]) + # attention and tp allreduce + hidden_states[i], residual[i] = self._forward_ms_op_attn( + positions[i], hidden_states[i], residual[i], kv_cache, + attn_metadata[i]) + + # block 3 : shared experts + # if there is an allreduce ops in shared expert, we can overlap it with the computation of the + # shared expert for next microbatch or moe gating + for i in range(num_micro_batchs): + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.ATTN_AR_FINISH) + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_SE_COMP_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_SE_COMM_FINISH], + ) + with set_multistream_context(context, i): + # compute shared expert after finishing ATTN AR + hidden_states[i], residual[ + i] = self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i]) + + num_token, hidden_dim = hidden_states[i].shape + hidden_states[i] = hidden_states[i].view(-1, hidden_dim) + num_tokens.append(num_token) + hidden_dims.append(hidden_dim) + if self.mlp.n_shared_experts is not None: + # TODO: we can move shared expert computation into next block if reduce results is false + shared_output = self.mlp._forward_ms_op_shared_expert( + hidden_states[i]) + shared_outputs.append(shared_output) + + # block 4 : moe + for i in range(num_micro_batchs): + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. 
+ if attn_metadata[i] is None: + # for profile run + is_prefill = True + enable_force_load_balance = True + else: + is_prefill = attn_metadata[i].num_prefills > 0 + enable_force_load_balance = False + + if self.mlp.tp_size > 1: + num_token, _ = hidden_states[i].shape + padded_num_tokens = (self.mlp.tp_size - num_token % + self.mlp.tp_size) % self.mlp.tp_size + if padded_num_tokens > 0: + hidden_states[i] = nn.functional.pad( + hidden_states[i], (0, 0, 0, padded_num_tokens)) + chunk_hidden_state = torch.tensor_split(hidden_states[i], + self.mlp.tp_size, + dim=0) + chunk_hidden_states.append(chunk_hidden_state) + local_hidden_states = chunk_hidden_state[self.mlp.tp_rank] + else: + local_hidden_states = hidden_states[i] + + router_logit = self.mlp._forward_ms_op_gate(local_hidden_states) + router_logits.append(router_logit) + + if CustomDeepseekDBOMoE.top_k: + real_top_k = CustomDeepseekDBOMoE.top_k + else: + real_top_k = self.mlp.experts.top_k + + hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp( + local_hidden_states, router_logits[i], is_prefill, real_top_k, + enable_force_load_balance) + + # the following kernels will be submitted to the comm stream to overlap the computation of the + # moe computation of next microbatch and the attn computation of next layer + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + ) + context.before_comm_event.record() + with torch.npu.stream(ms_metadata.communicate_stream): + context.before_comm_event.wait() + if self.mlp.experts.reduce_results and ( + self.mlp.experts.tp_size > 1 + or self.mlp.experts.ep_size > 1): + hidden_states[i] = tensor_model_parallel_all_reduce( + hidden_states[i]) + hidden_states[ + i] = hidden_states[i] * self.mlp.routed_scaling_factor + context.after_comm_event.record() + + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_AR_FINISH], + ) + with set_multistream_context(context, i): + if self.mlp.tp_size > 1: + hidden_states[i] = self.mlp._forward_ms_op_tp_allgather( + hidden_states[i], chunk_hidden_states[i], + padded_num_tokens) + with torch.npu.stream(ms_metadata.communicate_stream): + # last + if shared_outputs[i] is not None: + hidden_states[i] = hidden_states[i] + shared_outputs[i] + hidden_states[i] = hidden_states[i].view( + num_tokens[i], hidden_dims[i]) + if isinstance(self.mlp, CustomDeepseekDBOMLP + ) and hidden_states[i].dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states[i] *= 1. 
/ self.routed_scaling_factor + context.after_comm_event.record() + return hidden_states, residual + + # should split ops in Decoder Layer + def _forward_ms_op_input_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + return hidden_states, residual + + def _forward_ms_op_attn( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + return hidden_states, residual + + def _forward_ms_op_post_attn_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ): + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + return hidden_states, residual + + +class CustomDeepseekDBOModel(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.first_k_dense_replace = config.first_k_dense_replace + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CustomDeepseekDBODecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + # tbo related members + if VLLM_ASCEND_ENABLE_DBO: + self.use_mla = model_config.use_mla + self.multistream_config = MultiStreamConfig() + multistream_metadata = make_multistream_metadata_ds( + start_layer=self.start_layer + self.first_k_dense_replace, + end_layer=self.end_layer, + causal_lm=getattr(config, "causal_lm", True), + multistream_config=self.multistream_config, + ) + self.ms_pre_layer = MultiStreamPreTransformerLayer( + multistream_metadata) + self.ms_post_layer = MultiStreamPostTransformerLayer( + multistream_metadata) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + 
input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + num_normal_layers = (self.first_k_dense_replace + if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() + else self.end_layer - self.start_layer) + + for i in range(self.start_layer, self.start_layer + num_normal_layers): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, residual, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata) + + moe_start_layer = self.start_layer + num_normal_layers + if moe_start_layer != self.end_layer: + # if we enable multistream/dbo, process sparse layers here + hidden_states, residual = self._forward_ms_layers( + positions=positions, + hidden_states=hidden_states, + residual=residual, + moe_start_layer=moe_start_layer, + kv_caches=kv_caches, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def can_run_ms(self): + attn_metadata = get_forward_context().attn_metadata + # support mla attention and V1 engine at present + if not self.use_mla or not envs.VLLM_USE_V1: + return False + # enable prefill overlap + if attn_metadata is None or attn_metadata.num_prefills == 0: + return False + else: + [token_index, seq_index + ] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, + attn_metadata.num_decode_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + attn_metadata.query_lens): + return False + # check whether the total tokens exceed the threshold + if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: + return False + return True + + def _forward_ms_layers( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + moe_start_layer: int, + kv_caches: Optional[List[torch.Tensor]] = None, + is_prefill: bool = False, + ): + + if moe_start_layer == self.end_layer: + return hidden_states, residual + + attn_metadata, [positions, hidden_states, + residual] = self.ms_pre_layer( + [positions, hidden_states, residual], ) + # the rest layers + for i in range(moe_start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer._forward_ms_layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + attn_metadata=attn_metadata, + kv_cache=kv_caches[i - self.start_layer] + if kv_caches is not None else None, + is_prefill=is_prefill) + advance_step_multistream_layer_context() + + [hidden_states, + residual] = self.ms_post_layer([hidden_states, residual], ) + return hidden_states, residual + + +class CustomDeepseekDBOForCausalLM(DeepseekV2ForCausalLM): + # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + 
"experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = CustomDeepseekDBOModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index c69ccf5a14..8a1b8d29fb 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -66,22 +66,12 @@ from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend -from vllm_ascend.multistream.base import MSEventKey -from vllm_ascend.multistream.context import ( - advance_step_multistream_layer_context, get_multistream_comm_context, - get_multistream_layer_context, set_multistream_context) -from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, - MultiStreamPreTransformerLayer) -from vllm_ascend.multistream.metadata import (MultiStreamConfig, - MultiStreamStepMetadata, - make_multistream_metadata_ds) -from vllm_ascend.multistream.ms_split import compute_split_seq_index +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 -VLLM_ENABLE_DBO: bool = envs_ascend.VLLM_ENABLE_DBO class CustomDeepseekV2MLP(nn.Module): @@ -153,50 +143,6 @@ def forward(self, x): x, _ = self.down_proj(x) return x - def _forward_ms_mlp(self, x): - current_ms_metadata = get_multistream_comm_context() - assert current_ms_metadata is not None - if self.is_dynamic_quant: - x, dynamic_scale = torch_npu.npu_dynamic_quant(x) - x = torch_npu.npu_quant_matmul( - x, - self.gate_up_proj.weight, - self.gate_up_proj.weight_scale, - output_dtype=torch.int32, - ) - x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( - x=x, - weight_scale=self.gate_up_proj.weight_scale_fp32, - activation_scale=dynamic_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=None, - activate_left=True, - quant_mode=1) - x = torch_npu.npu_quant_matmul( - x, - self.down_proj.weight, - self.down_proj.weight_scale, - pertoken_scale=dynamic_scale, - output_dtype=torch.bfloat16, - ) - if self.down_proj.reduce_results and self.down_proj.tp_size > 1: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - 
current_ms_metadata.before_comm_event.wait() - x = tensor_model_parallel_all_reduce(x) - current_ms_metadata.after_comm_event.record() - return x - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - x, _ = self.down_proj(x) - current_ms_metadata.after_comm_event.record() - return x - class CustomDeepseekV2MoE(nn.Module): @@ -366,55 +312,6 @@ def forward( return hidden_states.view(num_tokens, hidden_size) - # ----------------------------------------- TBO-related -------------------------------------------- - def _forward_ms_op_shared_expert( - self, - hidden_states: torch.Tensor, - ): - shared_output = self.shared_experts._forward_ms_mlp(hidden_states) - return shared_output - - def _forward_ms_op_gate( - self, - hidden_states: torch.Tensor, - ): - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - return router_logits - - def _forward_ms_op_tp_allgather( - self, - hidden_states: torch.Tensor, - chunk_hidden_states: torch.Tensor, - num_tokens: int = 0, - ): - - if self.tp_size > 1: - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is None: - dist.all_gather(list(chunk_hidden_states), hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - #if num_tokens < self.tp_size: - # final_hidden_states = final_hidden_states[:num_tokens] - if num_tokens > 0: - final_hidden_states = final_hidden_states[:-num_tokens] - else: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - dist.all_gather(list(chunk_hidden_states), hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - #if num_tokens < self.tp_size: - # final_hidden_states = final_hidden_states[:num_tokens] - if num_tokens > 0: - final_hidden_states = final_hidden_states[:-num_tokens] - - else: - final_hidden_states = hidden_states - - return final_hidden_states - class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): @@ -708,229 +605,6 @@ def forward( return hidden_states, residual - # ----------------------------------------- TBO-related -------------------------------------------- - def _forward_ms_layer( - self, - positions: List[torch.Tensor], - hidden_states: List[torch.Tensor], - residual: List[torch.Tensor], - attn_metadata: List[AttentionMetadata], - kv_cache: Optional[torch.Tensor] = None, - is_prefill: bool = False, - ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( - ) - assert layer_index >= 0 and ms_metadata is not None - num_micro_batchs = ms_metadata.ms_config.num_micro_batches - assert isinstance(self.mlp, CustomDeepseekV2MoE) - assert len(positions) == num_micro_batchs - assert len(hidden_states) == num_micro_batchs - assert residual is not None - assert attn_metadata is not None - num_tokens = [] - hidden_dims = [] - shared_outputs = [] - router_logits = [] - chunk_hidden_states = [] - - # block 1 : attention - # block 2 : attn tp communication - # the attn computation of microbatch 1 can be overlapped with the moe - # communication in the previous layer, and the attn computation of microbatch 2 - # can be overlapped with the attn communication of microbatch 1 - for i in range(num_micro_batchs): - # wait last layer moe finishing communication - ms_metadata.try_wait_event(layer_index - 
1, i, - MSEventKey.FFN_AR_FINISH) - context = MultiStreamStepMetadata( - comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.ATTN_COM_FINISH], - after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.ATTN_AR_FINISH], - ) - - with set_multistream_context(context, i): - forward_context = get_forward_context() - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( - ) - forward_context.attn_metadata = attn_metadata[i] - - # input layernorm - hidden_states[i], residual[ - i] = self._forward_ms_op_input_layernorm( - hidden_states[i], residual[i]) - # attention and tp allreduce - hidden_states[i], residual[i] = self._forward_ms_op_attn( - positions[i], hidden_states[i], residual[i], kv_cache, - attn_metadata[i]) - - # block 3 : shared experts - # if there is an allreduce ops in shared expert, we can overlap it with the computation of the - # shared expert for next microbatch or moe gating - for i in range(num_micro_batchs): - ms_metadata.try_wait_event(layer_index, i, - MSEventKey.ATTN_AR_FINISH) - context = MultiStreamStepMetadata( - comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_SE_COMP_FINISH], - after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_SE_COMM_FINISH], - ) - with set_multistream_context(context, i): - # compute shared expert after finishing ATTN AR - hidden_states[i], residual[ - i] = self._forward_ms_op_post_attn_layernorm( - hidden_states[i], residual[i]) - - num_token, hidden_dim = hidden_states[i].shape - hidden_states[i] = hidden_states[i].view(-1, hidden_dim) - num_tokens.append(num_token) - hidden_dims.append(hidden_dim) - if self.mlp.n_shared_experts is not None: - # TODO: we can move shared expert computation into next block if reduce results is false - shared_output = self.mlp._forward_ms_op_shared_expert( - hidden_states[i]) - shared_outputs.append(shared_output) - - # block 4 : moe - for i in range(num_micro_batchs): - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. - # TODO: need a better flag to indicate whether in profile run or not. 
- if attn_metadata[i] is None: - # for profile run - is_prefill = True - enable_force_load_balance = True - else: - is_prefill = attn_metadata[i].num_prefills > 0 - enable_force_load_balance = False - - if self.mlp.tp_size > 1: - num_token, _ = hidden_states[i].shape - padded_num_tokens = (self.mlp.tp_size - num_token % - self.mlp.tp_size) % self.mlp.tp_size - if padded_num_tokens > 0: - hidden_states[i] = nn.functional.pad( - hidden_states[i], (0, 0, 0, padded_num_tokens)) - chunk_hidden_state = torch.tensor_split(hidden_states[i], - self.mlp.tp_size, - dim=0) - chunk_hidden_states.append(chunk_hidden_state) - local_hidden_states = chunk_hidden_state[self.mlp.tp_rank] - else: - local_hidden_states = hidden_states - - router_logit = self.mlp._forward_ms_op_gate(local_hidden_states) - router_logits.append(router_logit) - - if CustomDeepseekV2MoE.top_k: - real_top_k = CustomDeepseekV2MoE.top_k - else: - real_top_k = self.mlp.experts.top_k - - hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp( - local_hidden_states, router_logits[i], is_prefill, real_top_k, - enable_force_load_balance) - - # the following kernels will be submitted to the comm stream to overlap the computation of the - # moe computation of next microbatch and the attn computation of next layer - context = MultiStreamStepMetadata( - comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.FFN_COM_FINISH], - after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM], - ) - context.before_comm_event.record() - with torch.npu.stream(ms_metadata.communicate_stream): - context.before_comm_event.wait() - if self.mlp.experts.reduce_results and ( - self.mlp.experts.tp_size > 1 - or self.mlp.experts.ep_size > 1): - hidden_states[i] = tensor_model_parallel_all_reduce( - hidden_states[i]) - context.after_comm_event.record() - - hidden_states[ - i] = hidden_states[i] * self.mlp.routed_scaling_factor - context = MultiStreamStepMetadata( - comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM], - after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.FFN_AR_FINISH], - ) - with set_multistream_context(context, i): - hidden_states[i] = self.mlp._forward_ms_op_tp_allgather( - hidden_states[i], chunk_hidden_states[i], - padded_num_tokens) - with torch.npu.stream(ms_metadata.communicate_stream): - # last - if shared_output is not None: - hidden_states[i] = hidden_states[i] + shared_outputs[i] - hidden_states[i] = hidden_states[i].view( - num_tokens[i], hidden_dims[i]) - if isinstance(self.mlp, CustomDeepseekV2MLP - ) and hidden_states[i].dtype == torch.float16: - # Fix FP16 overflow - # Scaling the DeepseekV2MLP output, it is the input of - # input_layernorm of next decoder layer. - # The scaling of DeepseekV2MOE output would be done in the forward - # of DeepseekV2MOE - hidden_states[i] *= 1. 
/ self.routed_scaling_factor - context.after_comm_event.record() - return hidden_states, residual - - # should split ops in Decoder Layer - def _forward_ms_op_input_layernorm( - self, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, torch.Tensor]: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - return hidden_states, residual - - def _forward_ms_op_attn( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - if hidden_states.dtype == torch.float16: - # Fix FP16 overflow - # We scale both hidden_states and residual before - # rmsnorm, and rmsnorm result would not affect by scale. - hidden_states *= 1. / self.routed_scaling_factor - if self.layer_idx == 0: - # The residual is shared by all layers, we only scale it on - # first layer. - residual *= 1. / self.routed_scaling_factor - return hidden_states, residual - - def _forward_ms_op_post_attn_layernorm( - self, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ): - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - return hidden_states, residual - class CustomDeepseekV2Model(nn.Module): @@ -946,7 +620,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.first_k_dense_replace = config.first_k_dense_replace if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( @@ -976,21 +649,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - # tbo related members - if VLLM_ENABLE_DBO: - self.use_mla = model_config.use_mla - self.multistream_config = MultiStreamConfig() - multistream_metadata = make_multistream_metadata_ds( - start_layer=self.start_layer + self.first_k_dense_replace, - end_layer=self.end_layer, - causal_lm=getattr(config, "causal_lm", True), - multistream_config=self.multistream_config, - ) - self.ms_pre_layer = MultiStreamPreTransformerLayer( - multistream_metadata) - self.ms_post_layer = MultiStreamPostTransformerLayer( - multistream_metadata) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -1014,11 +672,7 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - num_normal_layers = (self.first_k_dense_replace if VLLM_ENABLE_DBO - and self.can_run_ms() else self.end_layer - - self.start_layer) - - for i in range(self.start_layer, self.start_layer + num_normal_layers): + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, residual, @@ -1026,17 +680,6 @@ def forward( self.start_layer] if kv_caches is not None else None, attn_metadata) - moe_start_layer = self.start_layer + num_normal_layers - if moe_start_layer != self.end_layer: - # if we enable multistream/dbo, process sparse layers here - hidden_states, residual = self._forward_ms_layers( - 
positions=positions, - hidden_states=hidden_states, - residual=residual, - moe_start_layer=moe_start_layer, - kv_caches=kv_caches, - ) - if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -1046,60 +689,6 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def can_run_ms(self): - attn_metadata = get_forward_context().attn_metadata - # support mla attention and V1 engine at present - if not self.use_mla or not envs.VLLM_USE_V1: - return False - # enable prefill overlap - if attn_metadata is None or attn_metadata.num_prefills == 0: - return False - else: - [token_index, seq_index - ] = compute_split_seq_index(attn_metadata.query_lens, - attn_metadata.attn_state, - attn_metadata.num_decode_tokens) - if token_index == 0 or seq_index == 0 or seq_index == len( - attn_metadata.query_lens): - return False - # check whether the total tokens exceed the threshold - if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: - return False - return True - - def _forward_ms_layers( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor, - moe_start_layer: int, - kv_caches: Optional[List[torch.Tensor]] = None, - is_prefill: bool = False, - ): - - if moe_start_layer == self.end_layer: - return hidden_states, residual - - attn_metadata, [positions, hidden_states, - residual] = self.ms_pre_layer( - [positions, hidden_states, residual], ) - # the rest layers - for i in range(moe_start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer._forward_ms_layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - attn_metadata=attn_metadata, - kv_cache=kv_caches[i - self.start_layer] - if kv_caches is not None else None, - is_prefill=is_prefill) - advance_step_multistream_layer_context() - - [hidden_states, - residual] = self.ms_post_layer([hidden_states, residual], ) - return hidden_states, residual - class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging diff --git a/vllm_ascend/multistream/context.py b/vllm_ascend/multistream/context.py index afc8ba0235..a1684f2f55 100644 --- a/vllm_ascend/multistream/context.py +++ b/vllm_ascend/multistream/context.py @@ -23,7 +23,7 @@ def reset_multistream_layer_context(): """ reset multistream layer context """ - global _ms_layer_index_context, _ms_metadata_context + global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context _ms_layer_index_context = -1 _ms_metadata_context = None _ms_attn_metadata_context = None diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 3af6337e47..430f57b03a 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -223,6 +223,7 @@ def model_input_split_v1_mla_attn( attn_mask=attn_mask_pre, prefill=prefill_pre, decode=decode_pre, + with_prefill_across_dp=attn_metadata.with_prefill_across_dp, ) attention_metadata_post = _metadata_cls( num_actual_tokens=attn_metadata.num_actual_tokens - token_index, @@ -239,5 +240,6 @@ def model_input_split_v1_mla_attn( attn_state=attn_state_post, prefill=prefill_post, decode=decode_post, + with_prefill_across_dp=attn_metadata.with_prefill_across_dp, ) return [attention_metadata_pre, attention_metadata_post] From f9230b331df9476ece516857d62069da8a2bb168 Mon Sep 17 00:00:00 2001 From: zhuohuan 
Date: Fri, 6 Jun 2025 18:27:13 +0800 Subject: [PATCH 10/11] [feat]: update tests and example for dbo Signed-off-by: zhuohuan --- examples/offline_dualbatch_overlap_npu.py | 30 +++++++++---------- .../test_offline_inference_distributed.py | 8 ++--- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/examples/offline_dualbatch_overlap_npu.py b/examples/offline_dualbatch_overlap_npu.py index 6c4974647c..e7fd9b0358 100644 --- a/examples/offline_dualbatch_overlap_npu.py +++ b/examples/offline_dualbatch_overlap_npu.py @@ -8,31 +8,31 @@ os.environ["VLLM_USE_V1"] = "1" # Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] * 10 +prompts = ["The president of the United States is"] * 41 # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams(max_tokens=100, temperature=0.0) def main(): # Create an LLM. llm = LLM( - model="deepseek-ai/DeepSeek-V2-Lite", + model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic", hf_overrides={ "architectures": ["DeepseekDBOForCausalLM"], }, # override the model arch to run the dbo model enforce_eager=True, - tensor_parallel_size=8, - max_num_seqs=16, - max_model_len=8192, - max_num_batched_tokens=32768, - block_size=128, - compilation_config=1, - gpu_memory_utilization=0.96) + tensor_parallel_size=2, + max_model_len=4096, + trust_remote_code=True, + additional_config={ + "torchair_graph_config": { + "enabled": False + }, + "ascend_scheduler_config": { + "enabled": True + }, + "expert_tensor_parallel_size": 1 + }) # Generate texts from the prompts. The output is a list of RequestOutput # objects that contain the prompt, generated text, and other information. 
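Note on the example above: the hunk's trailing context stops just before the generation call. A minimal sketch of how such an offline run typically finishes is shown below; `run_and_report` and the print format are illustrative helpers, not code from this patch — only `llm.generate`, `RequestOutput.prompt`, and `outputs[0].text` are standard vLLM API.

def run_and_report(llm, prompts, sampling_params):
    # Generate with the DBO-enabled engine and print each completion.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated: {generated_text!r}")

Inside main() this would be invoked as run_and_report(llm, prompts, sampling_params).
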
diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index b41860c283..7587ba6bde 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -85,11 +85,9 @@ def test_models_distributed_topk() -> None: @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"}) def test_models_distributed_DeepSeek_dbo(): - example_prompts = [ - "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", - ] * 10 + example_prompts = ["The president of the United States is"] * 41 dtype = "half" - max_tokens = 5 + sampling_params = SamplingParams(max_tokens=100, temperature=0.0) with VllmRunner( "deepseek-ai/DeepSeek-V2-Lite", dtype=dtype, @@ -99,4 +97,4 @@ def test_models_distributed_DeepSeek_dbo(): "architectures": ["DeepseekDBOForCausalLM"], } # override the model arch to the dbo version ) as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) + vllm_model.generate(example_prompts, sampling_params) From 22cd249a82fad208794f82fac2fad13f257ced83 Mon Sep 17 00:00:00 2001 From: zhuohuan Date: Fri, 6 Jun 2025 22:01:39 +0800 Subject: [PATCH 11/11] [fix]: use env varibles to enable dbo model Signed-off-by: zhuohuan --- examples/offline_dualbatch_overlap_npu.py | 32 ++++++++----------- .../test_offline_inference_distributed.py | 3 -- vllm_ascend/models/__init__.py | 17 ++++++---- 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/examples/offline_dualbatch_overlap_npu.py b/examples/offline_dualbatch_overlap_npu.py index e7fd9b0358..d8153e38ca 100644 --- a/examples/offline_dualbatch_overlap_npu.py +++ b/examples/offline_dualbatch_overlap_npu.py @@ -15,24 +15,20 @@ def main(): # Create an LLM. - llm = LLM( - model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic", - hf_overrides={ - "architectures": ["DeepseekDBOForCausalLM"], - }, # override the model arch to run the dbo model - enforce_eager=True, - tensor_parallel_size=2, - max_model_len=4096, - trust_remote_code=True, - additional_config={ - "torchair_graph_config": { - "enabled": False - }, - "ascend_scheduler_config": { - "enabled": True - }, - "expert_tensor_parallel_size": 1 - }) + llm = LLM(model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic", + enforce_eager=True, + tensor_parallel_size=2, + max_model_len=4096, + trust_remote_code=True, + additional_config={ + "torchair_graph_config": { + "enabled": False + }, + "ascend_scheduler_config": { + "enabled": True + }, + "expert_tensor_parallel_size": 1 + }) # Generate texts from the prompts. The output is a list of RequestOutput # objects that contain the prompt, generated text, and other information. 
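With this patch the DBO model is selected by an environment variable instead of `hf_overrides`, so the example no longer overrides the architecture. A minimal opt-in sketch, mirroring the `os.environ` pattern already used at the top of this example (the toggle names come from this series; everything else is standard Python):

import os

# Enable dual-batch overlap before vllm_ascend registers its models;
# VLLM_ASCEND_ENABLE_DBO is the switch read in vllm_ascend/models/__init__.py.
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"
# The DBO path currently targets the V1 engine (see can_run_ms()).
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams  # import after setting the toggles
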
diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index 7587ba6bde..50675cf2d4 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -93,8 +93,5 @@ def test_models_distributed_DeepSeek_dbo(): dtype=dtype, tensor_parallel_size=4, distributed_executor_backend="mp", - hf_overrides={ - "architectures": ["DeepseekDBOForCausalLM"], - } # override the model arch to the dbo version ) as vllm_model: vllm_model.generate(example_prompts, sampling_params) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 4e9a7289e9..435778713c 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -1,5 +1,7 @@ from vllm import ModelRegistry +import vllm_ascend.envs as envs + def register_model(): from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401 @@ -23,9 +25,14 @@ def register_model(): "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration" ) - ModelRegistry.register_model( - "DeepseekV2ForCausalLM", - "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") + if envs.VLLM_ASCEND_ENABLE_DBO: + ModelRegistry.register_model( + "DeepseekV2ForCausalLM", + "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") + else: + ModelRegistry.register_model( + "DeepseekV2ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") ModelRegistry.register_model( "DeepseekV3ForCausalLM", @@ -34,7 +41,3 @@ def register_model(): ModelRegistry.register_model( "Qwen3MoeForCausalLM", "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM") - - ModelRegistry.register_model( - "DeepseekDBOForCausalLM", - "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
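
Taken together, the registration above means a stock `DeepseekV2ForCausalLM` checkpoint is served by the DBO implementation whenever `VLLM_ASCEND_ENABLE_DBO` is set, with no per-request overrides. The overlap itself relies on the event handshake used throughout `_forward_ms_layer`; the standalone sketch below restates that pattern under the assumption that `torch.npu` streams and events behave like their `torch.cuda` counterparts — `compute_fn`/`comm_fn` are placeholders, not functions from this series.

import torch
import torch_npu  # noqa: F401  (provides the torch.npu namespace on Ascend)

comm_stream = torch.npu.Stream()
before_comm = torch.npu.Event()
after_comm = torch.npu.Event()

def overlapped_step(compute_fn, comm_fn, x):
    # Default stream: computation for the current micro-batch.
    y = compute_fn(x)
    # Mark the point the communication depends on.
    before_comm.record()
    with torch.npu.stream(comm_stream):
        # The comm stream waits only for this micro-batch's compute result,
        # leaving the default stream free to start the next micro-batch.
        before_comm.wait()
        y = comm_fn(y)          # e.g. tensor-parallel all-reduce / all-gather
        after_comm.record()     # later stages wait on this event
    return y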