
Conversation


@shikang-hangzhou shikang-hangzhou commented Jul 2, 2025

What this PR does / why we need it?

  1. Add EP-parallel support to the DBO model
  2. Optimize the dual-stream overlap

Configuration: max_tokens 32784, input_len 1024, batch size 32, dp2tp8ep16

Before enabling DBO:
TTFT: 4017 ms
(screenshot: before)

After enabling DBO:
TTFT: 3017 ms
(screenshot: after)

Does this PR introduce any user-facing change?

None

How was this patch tested?

Signed-off-by: shikang-hangzhou <459956190@qq.com>
attn_cls = CustomDeepseekDBOMLAAttention
else:
attn_cls = DeepseekV2Attention
attn_cls = CustomDeepseekV2MLAAttention
Collaborator

Why remove the branch for the case when use_mla is False here?

Author

Dual-stream overlap is an optimization mode. DeepSeek with MHA is not part of our application scenario and shows no improvement from this optimization, so I think the MHA branch is unnecessary here.
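For reference, a minimal sketch of the change under discussion, assuming the DBO path always runs with MLA attention; the commented-out branch mirrors the removed code and the condition shown is illustrative, not the exact PR diff:

```python
# Before: the DBO path still carried an MHA fallback (illustrative condition).
# if config.use_mla:
#     attn_cls = CustomDeepseekDBOMLAAttention
# else:
#     attn_cls = DeepseekV2Attention
#
# After: DBO assumes MLA, so the non-MLA branch is dropped.
attn_cls = CustomDeepseekDBOMLAAttention
```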

hidden_states[i], router_logits[i], is_prefill, real_top_k,
enable_force_load_balance)

if global_num_experts == 256:
Collaborator

Please add a comment explaining that we use 256 here because the op npu_moe_gating_top_k only supports that expert count.

Author

Please add a comment explaining that we use 256 here because the op npu_moe_gating_top_k only supports that expert count.

Thanks for your review; we have added the comment.
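For context, a hedged sketch of the guard and comment being discussed; `fused_gating_topk` is a hypothetical wrapper around the torch_npu op, and the fallback path is illustrative rather than the PR's actual code:

```python
import torch

def select_topk(router_logits: torch.Tensor, top_k: int, global_num_experts: int):
    # NOTE: we branch on 256 because the fused op npu_moe_gating_top_k
    # only supports exactly 256 experts.
    if global_num_experts == 256:
        # Hypothetical wrapper around torch_npu.npu_moe_gating_top_k;
        # the real signature is elided since it is not shown in the PR snippet.
        return fused_gating_topk(router_logits, top_k)
    # Generic fallback: plain softmax + top-k gating.
    scores = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=-1)
    return topk_weights, topk_ids
```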

if self.dp_size > 1:
    if (self.tp_size > 1
            and fused_moe_state != FusedMoEState.AllGather):
        dist.all_gather(list(chunk_hidden_states[i]),
Collaborator

I recommend using tensor_model_parallel_all_gather directly.

Author

This follows the existing DeepseekV2 code here.
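For illustration, a sketch of the reviewer's suggestion; `tensor_model_parallel_all_gather` is vLLM's tensor-parallel gather helper, while the gather dimension and surrounding names are assumptions, not the PR's code:

```python
import torch
from vllm.distributed import tensor_model_parallel_all_gather

def gather_hidden_states(hidden_states: torch.Tensor, tp_size: int):
    # Instead of dist.all_gather into a manually prepared list of chunks,
    # the helper gathers across the TP group and returns one tensor.
    gathered = tensor_model_parallel_all_gather(hidden_states, dim=0)
    # Split back into per-rank chunks, mirroring chunk_hidden_states
    # in the original snippet.
    return torch.chunk(gathered, tp_size, dim=0)
```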

Signed-off-by: shikang-hangzhou <459956190@qq.com>
MSEventKey.MOE_ALL_TO_ALL_FINISH],
)
context.before_comm_event.record()
with torch.npu.stream(ms_metadata.communicate_stream):
Collaborator

This kind of stream-control method doesn't seem to be capturable by torchair, so this is an eager-mode-only dual-batch implementation, right?

Author

Yes, so dual-stream overlap only takes effect during the prefill process.
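A minimal eager-mode sketch of the pattern in the snippet above, assuming the torch.npu stream/event API (which mirrors torch.cuda); the names `comm_stream`, `before_comm_event`, and `after_comm_event` are illustrative:

```python
import torch
import torch_npu  # noqa: F401  (provides the torch.npu namespace)

comm_stream = torch.npu.Stream()
before_comm_event = torch.npu.Event()
after_comm_event = torch.npu.Event()

hidden_states = torch.randn(16, 1024, device="npu")

# Compute stream: produce the inputs for the MoE dispatch, then record.
before_comm_event.record()

with torch.npu.stream(comm_stream):
    # The communication stream waits only for the compute it depends on.
    before_comm_event.wait()
    # dist.all_to_all_single(...) for the MoE dispatch would run here.
    after_comm_event.record()

# Back on the default stream: shared-expert work can overlap the all-to-all;
# wait on the event only right before its output is consumed.
after_comm_event.wait()
```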


for i in range(num_micro_batchs):
    ms_metadata.try_wait_event(layer_index, i,
                               MSEventKey.MOE_ALL_TO_ALL_FINISH)
Collaborator

What's the difference between event.wait and try_wait_event?

Author

It's the same here; I will modify it.

Author

@shikang-hangzhou shikang-hangzhou Jul 4, 2025


What's the difference between event.wait and try_wait_event?

Sorry, there is a difference between event.wait and try_wait_event: the latter can target the event of a specific micro-batch that needs to be waited on.
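To make the distinction concrete, a hypothetical sketch of how try_wait_event can differ from a bare event.wait(): events are keyed per layer and micro-batch, so the caller can wait on one specific micro-batch only. The class and field names are illustrative, not the PR's implementation:

```python
from enum import Enum, auto

import torch
import torch_npu  # noqa: F401  (provides torch.npu.Event)

class MSEventKey(Enum):
    MOE_ALL_TO_ALL_FINISH = auto()

class MultiStreamMetadata:
    def __init__(self, num_layers: int, num_micro_batches: int):
        # One event per (layer, micro-batch, key).
        self.events = {
            (layer, mb, key): torch.npu.Event()
            for layer in range(num_layers)
            for mb in range(num_micro_batches)
            for key in MSEventKey
        }

    def try_wait_event(self, layer_index: int, micro_batch: int, key: MSEventKey):
        # Unlike a plain event.wait() on some ambient event, this targets
        # the event of one specific micro-batch in one specific layer.
        self.events[(layer_index, micro_batch, key)].wait()
```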

ep_group.world_size, -1).sum(-1)
scatter_sizes.append(scatter_size)
gather_sizes = torch.empty_like(scatter_sizes[i])
dist.all_to_all_single(gather_sizes,
Collaborator

I wonder if my understanding is correct: you are trying to overlap the all_to_all with gating_topk, right? Since the launch on the second stream needs to wait for the end of gating.

Collaborator

And you overlap the combine phase of the all-to-all with the computation of the shared expert.
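A toy, hedged sketch of that schedule for two micro-batches: the all-to-all of one micro-batch (on the communication stream) overlaps the gating of the other, and the combine phase overlaps the shared-expert computation. All function bodies below are placeholders, not the PR's kernels, and the event handshake from the earlier sketch is abbreviated:

```python
import torch
import torch_npu  # noqa: F401  (provides the torch.npu namespace)

def gating_topk(x):      # placeholder for the real gating op
    return x.topk(2, dim=-1)

def all_to_all(x):       # placeholder for dist.all_to_all_single
    return x

def shared_experts(x):   # placeholder for the shared-expert MLP
    return x * 2

comm_stream = torch.npu.Stream()
hidden = [torch.randn(4, 8, device="npu") for _ in range(2)]
a2a_done = [torch.npu.Event() for _ in range(2)]
dispatched = [None, None]

for mb in range(2):
    routing = gating_topk(hidden[mb])            # compute stream
    with torch.npu.stream(comm_stream):
        # Real code records/waits a before-comm event here (see earlier sketch);
        # this dispatch overlaps the gating of the next micro-batch.
        dispatched[mb] = all_to_all(hidden[mb])
        a2a_done[mb].record()

for mb in range(2):
    shared_out = shared_experts(hidden[mb])      # overlaps the combine
    a2a_done[mb].wait()                          # wait just before consuming
    out = dispatched[mb] + shared_out
```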

Signed-off-by: shikang-hangzhou <459956190@qq.com>
Signed-off-by: shikang-hangzhou <459956190@qq.com>
@ganyi1996ppo ganyi1996ppo merged commit 4f007e8 into vllm-project:v0.9.1-dev Jul 4, 2025
16 checks passed
@Yikun Yikun added the no-main label Jul 7, 2025
22dimensions pushed a commit to 22dimensions/vllm-ascend that referenced this pull request Jul 22, 2025
…overlap (vllm-project#1589)

1. DBO model support EP parallel
2. optimize dual stream overlap

max tokens:32784 input_len:1024  bs 32 dp2tp8ep16
before open dbo
TTFT: 4017ms

![before](https://github.com/user-attachments/assets/8f9e338d-978f-42cf-9add-825a8dd3418f)

after open dbo
TTFT: 3017ms

![after](https://github.com/user-attachments/assets/79f706fa-22c8-4c71-b5e3-ae3f53dac23b)
None

---------

Signed-off-by: shikang-hangzhou <459956190@qq.com>
22dimensions pushed a commit to 22dimensions/vllm-ascend that referenced this pull request Jul 22, 2025
…m-project#1420 vllm-project#1328 from v0.9.1-dev to main

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
22dimensions pushed a commit to 22dimensions/vllm-ascend that referenced this pull request Jul 23, 2025
…m-project#1420 vllm-project#1328 from v0.9.1-dev to main

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
