51 changes: 51 additions & 0 deletions examples/offline_dualbatch_overlap_npu.py
@@ -0,0 +1,51 @@
import os
import time

from vllm import LLM, SamplingParams

# enable dual-batch overlap for vllm ascend
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"
os.environ["VLLM_USE_V1"] = "1"

# Sample prompts.
prompts = ["The president of the United States is"] * 41
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)


def main():
    # Create an LLM.
    llm = LLM(model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic",
              enforce_eager=True,
              tensor_parallel_size=2,
              max_model_len=4096,
              trust_remote_code=True,
              additional_config={
                  "torchair_graph_config": {
                      "enabled": False
                  },
                  "ascend_scheduler_config": {
                      "enabled": True
                  },
                  "expert_tensor_parallel_size": 1
              })

    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)

    # Add a buffer to wait for profiler in the background process
    # (in case MP is on) to finish writing profiling output.
    time.sleep(10)


if __name__ == "__main__":
    main()
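Since the entries in vllm_ascend/envs.py are lambdas, the two toggles in this example are read lazily; a minimal variant of the header (a sketch, assuming only the standard library) keeps them overridable from the shell instead of hard-coding them:

import os

# Fill in defaults without clobbering values exported by the caller, so
# `VLLM_ASCEND_ENABLE_DBO=0 python offline_dualbatch_overlap_npu.py` still
# gives a non-DBO baseline for comparison.
os.environ.setdefault("VLLM_ASCEND_ENABLE_DBO", "1")
os.environ.setdefault("VLLM_USE_V1", "1")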
14 changes: 14 additions & 0 deletions tests/multicard/test_offline_inference_distributed.py
@@ -81,3 +81,17 @@ def test_models_distributed_topk() -> None:
            distributed_executor_backend="mp",
    ) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)


@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
def test_models_distributed_DeepSeek_dbo():
    example_prompts = ["The president of the United States is"] * 41
    dtype = "half"
    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
    with VllmRunner(
            "deepseek-ai/DeepSeek-V2-Lite",
            dtype=dtype,
            tensor_parallel_size=4,
            distributed_executor_backend="mp",
    ) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)
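The new test only asserts that generation completes with the flag set. A hedged sketch of a stronger check, not part of this change, assuming VllmRunner behaves as above and that both engines can be brought up one after another on the same four cards:

def _greedy_generate(dbo: bool, prompts, sampling_params):
    env = {"VLLM_ASCEND_ENABLE_DBO": "1" if dbo else "0"}
    with patch.dict(os.environ, env):
        with VllmRunner("deepseek-ai/DeepSeek-V2-Lite",
                        dtype="half",
                        tensor_parallel_size=4,
                        distributed_executor_backend="mp") as vllm_model:
            return vllm_model.generate(prompts, sampling_params)

With temperature=0.0 the DBO and non-DBO runs should produce near-identical texts, which would surface the accuracy concern mentioned in the FIX comment in mla_v1.py below.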
69 changes: 61 additions & 8 deletions vllm_ascend/attention/mla_v1.py
@@ -13,6 +13,9 @@

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla

if TYPE_CHECKING:
@@ -117,6 +120,7 @@ class AscendMLAMetadata:

    with_prefill_across_dp: bool = False

    query_lens: Optional[list[int]] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    attn_mask: torch.Tensor = None
@@ -135,6 +139,17 @@ def __post_init__(self):
        #     f"Only {supported_head_sizes} are supported for head_dim,",
        #     f"received {self.head_dim}.")

    def split_metadata_for_multistream(
        self,
        ms_split_config: MSAttentionMetadataSplitConfig,
    ) -> list["AscendMLAMetadata"]:
        """Split metadata for multi-stream with AscendMLAMetadata"""
        return model_input_split_v1_mla_attn(
            ms_split_config=ms_split_config,
            attn_metadata=self,
            _metadata_cls=AscendMLAMetadata,
        )
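model_input_split_v1_mla_attn lives in vllm_ascend/multistream/ms_split.py and is not part of this diff. The sketch below only illustrates the general idea behind such a split, cutting the flattened token range at a request boundary so each micro-batch gets a contiguous slice of query_lens and slot_mapping; the field names are hypothetical stand-ins, not the real AscendMLAMetadata layout:

from dataclasses import dataclass
from typing import List


@dataclass
class _ToyAttnMetadata:  # hypothetical stand-in, not AscendMLAMetadata
    query_lens: List[int]  # tokens contributed by each request
    slot_mapping: List[int]  # one KV-cache slot per token


def split_in_two(meta: _ToyAttnMetadata) -> List[_ToyAttnMetadata]:
    """Cut the batch at a request boundary near half of the total tokens."""
    total = sum(meta.query_lens)
    running, cut = 0, 0
    for qlen in meta.query_lens:
        if running >= total // 2:
            break
        running += qlen
        cut += 1
    return [
        _ToyAttnMetadata(meta.query_lens[:cut], meta.slot_mapping[:running]),
        _ToyAttnMetadata(meta.query_lens[cut:], meta.slot_mapping[running:]),
    ]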


M = TypeVar("M", bound=AscendMLAMetadata)

@@ -386,6 +401,7 @@ def build(

        return self.metadata_cls(  # type: ignore
            num_actual_tokens=num_actual_tokens,
            query_lens=query_lens.tolist(),
            slot_mapping=slot_mapping,
            head_dim=self.runner.model_config.get_head_size(),
            num_decodes=self._num_decodes,
@@ -585,7 +601,15 @@ def _forward_prefill(
        )
        attn_output = attn_output.reshape(
            [num_tokens, self.num_heads * self.v_head_dim])
        return self.o_proj(attn_output)[0]

        # In DBO mode, hand the output projection (and its tensor-parallel
        # all-reduce) to the communication stream so it can overlap with the
        # other micro-batch's compute on the default stream.
        current_ms_metadata = get_multistream_comm_context()
        if current_ms_metadata is None:
            return self.o_proj(attn_output)[0]
        else:
            current_ms_metadata.before_comm_event.record()
            with torch.npu.stream(current_ms_metadata.comm_stream):
                current_ms_metadata.before_comm_event.wait()
                return self.o_proj(attn_output)[0]

    def exec_kv(
        self,
@@ -685,7 +709,14 @@ def _forward_decode(
            context_lens=attn_metadata.decode.seq_lens,  # type:ignore
            mla_vheadsize=self.kv_lora_rank,
            out=attn_output)
        return self._v_up_proj_and_o_proj(attn_output)
        current_ms_metadata = get_multistream_comm_context()
        if current_ms_metadata is None:
            return self._v_up_proj_and_o_proj(attn_output)
        else:
            # Same DBO pattern as in _forward_prefill: run the value up-projection
            # and output projection on the comm stream once attn_output is ready.
            current_ms_metadata.before_comm_event.record()
            with torch.npu.stream(current_ms_metadata.comm_stream):
                current_ms_metadata.before_comm_event.wait()
                return self._v_up_proj_and_o_proj(attn_output)

    def forward(
        self,
@@ -811,16 +842,38 @@ def forward(
                key_cache=kv_cache,
                slot_indices=attn_metadata.slot_mapping.flatten())
        if has_prefill:
            output[num_decode_tokens:] = self._forward_prefill(
                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
                attn_metadata)
            # FIX: the AI Core data movement should also be placed on the comm
            # stream under DBO, otherwise it may affect accuracy.
            # TODO: use a more elegant way to overlap the copy.
            output_prefill = self._forward_prefill(prefill_q,
                                                   prefill_k_c_normed,
                                                   prefill_k_pe, kv_cache,
                                                   attn_metadata)
            current_ms_metadata = get_multistream_comm_context()
            if current_ms_metadata is not None:
                with torch.npu.stream(current_ms_metadata.comm_stream):
                    output[num_decode_tokens:] = output_prefill
                    current_ms_metadata.after_comm_event.record()
            else:
                output[num_decode_tokens:] = output_prefill

        if has_decode:
            if self.running_in_graph:
                return self._forward_decode(decode_ql_nope, decode_q_pe,
                                            decode_k_nope, decode_k_pe,
                                            kv_cache, attn_metadata)
            else:
                output[:num_decode_tokens] = self._forward_decode(
                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
                    kv_cache, attn_metadata)
                output_decode = self._forward_decode(decode_ql_nope,
                                                     decode_q_pe,
                                                     decode_k_nope,
                                                     decode_k_pe, kv_cache,
                                                     attn_metadata)
                current_ms_metadata = get_multistream_comm_context()
                if current_ms_metadata is not None:
                    with torch.npu.stream(current_ms_metadata.comm_stream):
                        output[:num_decode_tokens] = output_decode
                        current_ms_metadata.after_comm_event.record()
                else:
                    output[:num_decode_tokens] = output_decode

        return output_padded
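The comm_stream and the before/after events used above come from the multistream context that the DBO model runner installs; that machinery is outside this hunk. A minimal sketch of the underlying stream/event pattern, using torch.cuda objects as a stand-in for their torch.npu equivalents and a made-up helper name:

import torch


def overlapped_projection(attn_output: torch.Tensor,
                          o_proj,  # projection module whose collective should overlap
                          comm_stream: torch.cuda.Stream,
                          before_evt: torch.cuda.Event,
                          after_evt: torch.cuda.Event) -> torch.Tensor:
    # Mark the point on the compute stream where attn_output becomes ready.
    before_evt.record()
    with torch.cuda.stream(comm_stream):
        # The comm stream waits for that point, then runs the projection,
        # leaving the compute stream free for the other micro-batch.
        before_evt.wait()
        out = o_proj(attn_output)
        # Mark completion so downstream code can synchronize before reading `out`.
        after_evt.record()
    return out

The hunks above apply the same shape in _forward_prefill, _forward_decode, and forward, only with torch.npu.stream and events held on the multistream metadata.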
2 changes: 2 additions & 0 deletions vllm_ascend/envs.py
@@ -107,6 +107,8 @@
    # Whether to enable the trace recompiles from pytorch.
    "VLLM_ASCEND_TRACE_RECOMPILES":
    lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
    # Whether to enable dual-batch overlap (DBO) for vllm-ascend.
    "VLLM_ASCEND_ENABLE_DBO":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
    # Whether to enable the model execute time observe profile. Disable it when
    # running vllm ascend in production environment.
    "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
14 changes: 11 additions & 3 deletions vllm_ascend/models/__init__.py
@@ -1,7 +1,10 @@
from vllm import ModelRegistry

import vllm_ascend.envs as envs


def register_model():
    from .deepseek_dbo import CustomDeepseekDBOForCausalLM  # noqa: F401
    from .deepseek_mtp import CustomDeepSeekMTP  # noqa: F401
    from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # noqa: F401
    from .deepseek_v2 import CustomDeepseekV3ForCausalLM  # noqa: F401
@@ -22,9 +25,14 @@ def register_model():
        "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
    )

    ModelRegistry.register_model(
        "DeepseekV2ForCausalLM",
        "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
    if envs.VLLM_ASCEND_ENABLE_DBO:
        ModelRegistry.register_model(
            "DeepseekV2ForCausalLM",
            "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
    else:
        ModelRegistry.register_model(
            "DeepseekV2ForCausalLM",
            "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")

    ModelRegistry.register_model(
        "DeepseekV3ForCausalLM",