|
77 | 77 | from vllm_ascend.attention.attention import AttentionMaskBuilder |
78 | 78 | from vllm_ascend.attention.attention_v1 import AscendAttentionState |
79 | 79 | from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata |
| 80 | +from vllm_ascend.multistream.ms_split import compute_split_seq_index |
80 | 81 | from vllm_ascend.platform import NPUPlatform |
81 | 82 | from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler |
82 | 83 | from vllm_ascend.utils import ProfileExecuteDuration |
@@ -569,16 +570,38 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: |
569 | 570 | self.input_batch.refresh_sampling_metadata() |
570 | 571 |
|
571 | 572 | def _get_forward_metadata_across_dp( |
572 | | - self, total_num_scheduled_tokens: int, |
573 | | - with_prefill: bool) -> tuple[int, bool]: |
| 573 | + self, total_num_scheduled_tokens: int, with_prefill: bool, |
| 574 | + enable_dbo: bool) -> tuple[int, bool, bool]: |
574 | 575 | forward_metadata = torch.tensor( |
575 | | - [total_num_scheduled_tokens, with_prefill], |
| 576 | + [total_num_scheduled_tokens, with_prefill, not enable_dbo], |
576 | 577 | device="cpu", |
577 | 578 | dtype=torch.int32) |
578 | 579 | dist.all_reduce(forward_metadata, |
579 | 580 | op=ReduceOp.MAX, |
580 | 581 | group=get_dp_group().cpu_group) |
581 | | - return int(forward_metadata[0]), bool(forward_metadata[1] > 0) |
| 582 | + return int(forward_metadata[0]), bool( |
| 583 | + forward_metadata[1] > 0), not bool(forward_metadata[2] > 0) |
| 584 | + |
| 585 | + def _check_dbo_is_valid(self, query_lens: torch.Tensor, |
| 586 | + attn_state: AscendAttentionState, |
| 587 | + num_tokens: int) -> bool: |
| 588 | + # do the checks for dp + dbo |
| 589 | + if attn_state in [ |
| 590 | + AscendAttentionState.DecodeOnly, |
| 591 | + AscendAttentionState.SpecDecoding |
| 592 | + ]: |
| 593 | + return False |
| 594 | + # considering the case that one dp rank may enable dbo while others may not |
| 595 | + if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO: |
| 596 | + return False |
| 597 | + # TODO: remove it if token-level microbatch is enabled |
| 598 | + [token_index, |
| 599 | + seq_index] = compute_split_seq_index(query_lens, attn_state, |
| 600 | + num_tokens) |
| 601 | + if token_index == 0 or seq_index == 0 or seq_index == len( |
| 602 | + query_lens) or num_tokens < 256: |
| 603 | + return False |
| 604 | + return True |
582 | 605 |
|
583 | 606 | def get_model(self) -> nn.Module: |
584 | 607 | return self.model |
@@ -900,12 +923,16 @@ def _process_reqs( |
900 | 923 | with_prefill = attn_state not in [ |
901 | 924 | AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding |
902 | 925 | ] |
| 926 | + enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), |
| 927 | + attn_state, |
| 928 | + total_num_scheduled_tokens) |
903 | 929 |
|
904 | 930 | if self.dp_size > 1: |
905 | | - max_num_tokens, with_prefill = self._get_forward_metadata_across_dp( |
906 | | - total_num_scheduled_tokens, with_prefill) |
| 931 | + max_num_tokens, with_prefill, enable_dbo = self._get_forward_metadata_across_dp( |
| 932 | + total_num_scheduled_tokens, with_prefill, enable_dbo) |
907 | 933 | extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens |
908 | 934 | extra_builder_kwargs['with_prefill_across_dp'] = with_prefill |
| 935 | + extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo |
909 | 936 |
|
910 | 937 | # Add graph_pad_size here |
911 | 938 | if self.torchair_graph_enabled and not with_prefill: |
|
0 commit comments