|
29 | 29 | import numpy.typing as npt |
30 | 30 | import torch |
31 | 31 | import torch.nn as nn |
| 32 | +from torch.distributed import ReduceOp |
32 | 33 | from vllm.attention import AttentionType, get_attn_backend |
33 | 34 | from vllm.attention.layer import Attention |
34 | 35 | from vllm.config import CompilationLevel, VllmConfig |
|
59 | 60 |
|
60 | 61 | from vllm_ascend.attention.attention import AttentionMaskBuilder |
61 | 62 | from vllm_ascend.attention.attention_v1 import AscendAttentionState |
| 63 | +from vllm_ascend.patch.platform.patch_common.patch_distributed import \ |
| 64 | + get_dp_group |
62 | 65 | from vllm_ascend.platform import NPUPlatform |
63 | 66 | from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler |
64 | 67 |
|
@@ -318,6 +321,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): |
318 | 321 | False) and self.vllm_config.model_config.use_mla |
319 | 322 | self.use_cached_npu_graph = additional_config.get( |
320 | 323 | "use_cached_npu_graph", False) |
| 324 | + self.has_prefilled = False |
| 325 | + self.dp_group = get_dp_group() |
321 | 326 |
|
322 | 327 | def _update_states(self, scheduler_output: "SchedulerOutput") -> None: |
323 | 328 | """Update the cached states and the persistent batch with the scheduler |
@@ -624,6 +629,9 @@ def _process_reqs( |
624 | 629 | input_ids = torch.cat([input_ids, padding]) |
625 | 630 | positions = torch.cat([positions, padding]) |
626 | 631 |
|
| 632 | + if self.enable_torchair_graph_mode: |
| 633 | + self.sync_prefill_when_enable_graph(attn_metadata) |
| 634 | + |
627 | 635 | # Run forward pass |
628 | 636 | with set_forward_context(attn_metadata, |
629 | 637 | self.vllm_config, |
@@ -685,6 +693,41 @@ def _process_reqs( |
685 | 693 | return (attn_metadata, hidden_states, spec_decode_metadata, positions, |
686 | 694 | total_num_scheduled_tokens, sample_indices) |
687 | 695 |
|
| 696 | + def sync_prefill_when_enable_graph(self, attn_metadata): |
| 697 | + """ |
| 698 | + NOTE: This method is a temporary workaround for the deadlock that can occur in graph mode when some data-parallel ranks are still in the prefill (P) phase while others have already moved on to decode (D). |
| 699 | + It will be removed along with its related calls once the official solution is implemented. |
| 700 | + """ |
| 701 | + |
| 702 | + def has_prefilled_all_rank(has_prefilled: bool) -> bool: |
| 703 | + status = torch.tensor([has_prefilled], |
| 704 | + dtype=torch.int32, |
| 705 | + device="cpu") |
| 706 | + if self.dp_group: |
| 707 | + torch.distributed.all_reduce(status, |
| 708 | + op=ReduceOp.MIN, |
| 709 | + group=self.dp_group) |
| 710 | + aggregated_has_prefilled = bool(status.item()) |
| 711 | + return aggregated_has_prefilled |
| 712 | + |
| 713 | + if self.has_prefilled and attn_metadata.attn_state != AscendAttentionState.DecodeOnly: |
| 714 | + self.has_prefilled = False  # new prefill work arrived, so re-arm the cross-rank check |
| 715 | + |
| 716 | + if not self.has_prefilled: |
| 717 | + self.has_prefilled = has_prefilled_all_rank( |
| 718 | + attn_metadata.attn_state == AscendAttentionState.DecodeOnly) |
| 719 | + |
| 720 | + if self.dp_group: |
| 721 | + while not self.has_prefilled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: |
| 722 | + self._dummy_run(1)  # keep this decode-only rank stepping while other ranks still prefill |
| 723 | + tensor = torch.tensor([1], dtype=torch.int32, device="cpu")  # dummy all-reduce keeps this rank participating in the DP group |
| 724 | + torch.distributed.all_reduce(tensor, |
| 725 | + op=ReduceOp.MAX, |
| 726 | + group=self.dp_group) |
| 727 | + self.has_prefilled = has_prefilled_all_rank( |
| 728 | + attn_metadata.attn_state == |
| 729 | + AscendAttentionState.DecodeOnly) |
| 730 | + |
688 | 731 | def _calc_spec_decode_metadata( |
689 | 732 | self, |
690 | 733 | num_draft_tokens: np.ndarray, |
|
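For readers unfamiliar with the handshake above: the key idea is a MIN all-reduce over the data-parallel group, so the aggregated flag becomes 1 only once every rank has finished prefill, and decode-only ranks keep issuing dummy steps until then. The standalone sketch below illustrates just that MIN handshake (it omits the extra MAX all-reduce and the real `_dummy_run`); it assumes a 2-process gloo group standing in for the Ascend DP group, and names such as `run_rank` and `steps_until_prefilled` are illustrative only.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def has_prefilled_all_rank(has_prefilled: bool) -> bool:
    # MIN across ranks: the result is 1 only if every rank reported True.
    status = torch.tensor([int(has_prefilled)], dtype=torch.int32)
    dist.all_reduce(status, op=dist.ReduceOp.MIN)
    return bool(status.item())


def run_rank(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29517")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Illustrative schedule: rank 0 has already prefilled, rank 1 needs 3 more steps.
    steps_until_prefilled = 0 if rank == 0 else 3
    while not has_prefilled_all_rank(steps_until_prefilled == 0):
        # A decode-only rank would run a dummy forward here (the patch uses
        # self._dummy_run(1)) so that its collectives stay matched with the
        # ranks that are still prefilling.
        steps_until_prefilled = max(0, steps_until_prefilled - 1)

    print(f"rank {rank}: every DP rank has prefilled, safe to run graph-mode decode")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_rank, args=(2,), nprocs=2)

Because every rank leaves the loop on the same iteration (they all observe the same MIN result), the collectives stay matched and no rank can block the others, which is the property the patch relies on to avoid the deadlock.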