Skip to content

Commit f1353d5

Browse files
[0.9.1][Bugfix] fix dp error in dbo (#1291)
Fix a runtime error in DBO when dp_size > 1. Add conditional logic in `_get_forward_metadata_across_dp` to enable DBO. Signed-off-by: shikang-hangzhou <459956190@qq.com>
1 parent 6a3551e commit f1353d5

File tree

6 files changed

+53
-31
lines changed

6 files changed

+53
-31
lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class AscendMetadata:
134134
num_input_tokens: int = 0 # Number of tokens including padding.
135135

136136
with_prefill_across_dp: bool = False
137+
enable_dbo_across_dp: bool = False
137138

138139

139140
class AscendAttentionMetadataBuilder:
@@ -150,7 +151,8 @@ def build(self,
150151
num_actual_tokens,
151152
max_query_len,
152153
common_prefix_len,
153-
with_prefill_across_dp: bool = False):
154+
with_prefill_across_dp: bool = False,
155+
enable_dbo_across_dp: bool = False):
154156

155157
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
156158
)
@@ -177,7 +179,8 @@ def build(self,
177179
slot_mapping=slot_mapping,
178180
attn_mask=attn_mask,
179181
attn_state=attn_state,
180-
with_prefill_across_dp=with_prefill_across_dp)
182+
with_prefill_across_dp=with_prefill_across_dp,
183+
enable_dbo_across_dp=enable_dbo_across_dp)
181184
return attn_metadata
182185

183186

vllm_ascend/attention/mla_v1.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class AscendMLAMetadata:
138138

139139
max_num_tokens_across_dp: int = 0
140140
with_prefill_across_dp: bool = False
141+
enable_dbo_across_dp: bool = False
141142

142143
query_lens: Optional[list[int]] = None
143144
# The dimension of the attention heads
@@ -367,6 +368,7 @@ def build(
367368
graph_pad_size: int = -1,
368369
max_num_tokens_across_dp: int = 0,
369370
with_prefill_across_dp: bool = False,
371+
enable_dbo_across_dp: bool = False,
370372
) -> AscendMLAMetadata:
371373
assert self._num_decodes + self._num_prefills == num_reqs
372374

@@ -513,7 +515,7 @@ def build(
513515
seq_lens=seq_lens,
514516
max_num_tokens_across_dp=max_num_tokens_across_dp,
515517
with_prefill_across_dp=with_prefill_across_dp,
516-
)
518+
enable_dbo_across_dp=enable_dbo_across_dp)
517519

518520

519521
class AscendMLAImpl(MLAAttentionImpl):

vllm_ascend/models/deepseek_dbo.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@
7474
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
7575
MultiStreamStepMetadata,
7676
make_multistream_metadata_ds)
77-
from vllm_ascend.multistream.ms_split import compute_split_seq_index
7877
from vllm_ascend.ops.fused_moe import AscendFusedMoE
7978
from vllm_ascend.utils import dispose_tensor
8079

@@ -881,22 +880,8 @@ def forward(
881880

882881
def can_run_ms(self):
883882
attn_metadata = get_forward_context().attn_metadata
884-
# support mla attention and V1 engine at present
885-
if not self.use_mla or not envs.VLLM_USE_V1:
886-
return False
887883
# enable prefill overlap
888-
if attn_metadata is None or attn_metadata.num_prefills == 0:
889-
return False
890-
else:
891-
[token_index, seq_index
892-
] = compute_split_seq_index(attn_metadata.query_lens,
893-
attn_metadata.attn_state,
894-
attn_metadata.num_decode_tokens)
895-
if token_index == 0 or seq_index == 0 or seq_index == len(
896-
attn_metadata.query_lens):
897-
return False
898-
# check whether the total tokens exceed the threshold
899-
if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split:
884+
if attn_metadata is None or attn_metadata.num_prefills == 0 or not attn_metadata.enable_dbo_across_dp:
900885
return False
901886
return True
902887

vllm_ascend/multistream/ms_split.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,12 @@ def model_input_split_v1_mla_attn(
9696
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
9797
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
9898

99-
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
100-
query_start_loc_post = deepcopy(
101-
attn_metadata.query_start_loc[seq_index:]
102-
) - attn_metadata.query_start_loc[seq_index]
99+
query_start_loc_pre = query_start_loc_post = None
100+
if attn_metadata.query_start_loc is not None:
101+
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
102+
query_start_loc_post = deepcopy(
103+
attn_metadata.query_start_loc[seq_index:]
104+
) - attn_metadata.query_start_loc[seq_index]
103105
[block_table_pre,
104106
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
105107
seq_index)
@@ -224,6 +226,7 @@ def model_input_split_v1_mla_attn(
224226
prefill=prefill_pre,
225227
decode=decode_pre,
226228
with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
229+
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
227230
)
228231
attention_metadata_post = _metadata_cls(
229232
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
@@ -241,5 +244,6 @@ def model_input_split_v1_mla_attn(
241244
prefill=prefill_post,
242245
decode=decode_post,
243246
with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
247+
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
244248
)
245249
return [attention_metadata_pre, attention_metadata_post]

vllm_ascend/worker/model_runner_v1.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
from vllm_ascend.attention.attention import AttentionMaskBuilder
7878
from vllm_ascend.attention.attention_v1 import AscendAttentionState
7979
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
80+
from vllm_ascend.multistream.ms_split import compute_split_seq_index
8081
from vllm_ascend.platform import NPUPlatform
8182
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
8283
from vllm_ascend.utils import ProfileExecuteDuration
@@ -569,16 +570,38 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
569570
self.input_batch.refresh_sampling_metadata()
570571

571572
def _get_forward_metadata_across_dp(
572-
self, total_num_scheduled_tokens: int,
573-
with_prefill: bool) -> tuple[int, bool]:
573+
self, total_num_scheduled_tokens: int, with_prefill: bool,
574+
enable_dbo: bool) -> tuple[int, bool, bool]:
574575
forward_metadata = torch.tensor(
575-
[total_num_scheduled_tokens, with_prefill],
576+
[total_num_scheduled_tokens, with_prefill, not enable_dbo],
576577
device="cpu",
577578
dtype=torch.int32)
578579
dist.all_reduce(forward_metadata,
579580
op=ReduceOp.MAX,
580581
group=get_dp_group().cpu_group)
581-
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
582+
return int(forward_metadata[0]), bool(
583+
forward_metadata[1] > 0), not bool(forward_metadata[2] > 0)
584+
585+
def _check_dbo_is_valid(self, query_lens: torch.Tensor,
586+
attn_state: AscendAttentionState,
587+
num_tokens: int) -> bool:
588+
# do the checks for dp + dbo
589+
if attn_state in [
590+
AscendAttentionState.DecodeOnly,
591+
AscendAttentionState.SpecDecoding
592+
]:
593+
return False
594+
# considering the case that one dp rank may enable dbo while others may not
595+
if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO:
596+
return False
597+
# TODO: remove it if token-level microbatch is enabled
598+
[token_index,
599+
seq_index] = compute_split_seq_index(query_lens, attn_state,
600+
num_tokens)
601+
if token_index == 0 or seq_index == 0 or seq_index == len(
602+
query_lens) or num_tokens < 256:
603+
return False
604+
return True
582605

583606
def get_model(self) -> nn.Module:
584607
return self.model
@@ -900,12 +923,16 @@ def _process_reqs(
900923
with_prefill = attn_state not in [
901924
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
902925
]
926+
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
927+
attn_state,
928+
total_num_scheduled_tokens)
903929

904930
if self.dp_size > 1:
905-
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
906-
total_num_scheduled_tokens, with_prefill)
931+
max_num_tokens, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
932+
total_num_scheduled_tokens, with_prefill, enable_dbo)
907933
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
908934
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
935+
extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
909936

910937
# Add graph_pad_size here
911938
if self.torchair_graph_enabled and not with_prefill:

vllm_ascend/worker/worker_v1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,10 @@ def execute_dummy_batch(self) -> None:
251251
runner = self.model_runner
252252
max_num_tokens = 1
253253
with_prefill = False
254+
enable_dbo = False
254255
if runner.dp_size > 1:
255-
max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
256-
max_num_tokens, with_prefill)
256+
max_num_tokens, with_prefill, _ = runner._get_forward_metadata_across_dp(
257+
max_num_tokens, with_prefill, enable_dbo)
257258
if runner.torchair_graph_enabled and not with_prefill:
258259
max_num_tokens = runner.select_torchair_padded_batch_size(
259260
max_num_tokens)

0 commit comments

Comments (0)