|
77 | 77 | from vllm_ascend.ascend_forward_context import set_ascend_forward_context |
78 | 78 | from vllm_ascend.attention.attention import AttentionMaskBuilder |
79 | 79 | from vllm_ascend.attention.attention_v1 import AscendAttentionState |
80 | | -from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata |
| 80 | +from vllm_ascend.attention.utils import \ |
| 81 | + AscendCommonAttentionMetadata as CommonAttentionMetadata |
81 | 82 | from vllm_ascend.multistream.ms_split import compute_split_seq_index |
82 | 83 | from vllm_ascend.platform import NPUPlatform |
83 | 84 | from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler |
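The `CommonAttentionMetadata` name now aliases `AscendCommonAttentionMetadata` from `vllm_ascend.attention.utils` instead of the class previously imported from `mla_v1`. A minimal sketch of what such a container might look like, assuming only the two fields the removed call site below populates (`query_start_loc` and `seq_lens`); the real class may well carry more:

```python
from dataclasses import dataclass

import torch


@dataclass
class AscendCommonAttentionMetadata:
    """Sketch only: fields inferred from the old CommonAttentionMetadata call site."""
    # Cumulative query offsets per request, shape (num_reqs + 1,).
    query_start_loc: torch.Tensor
    # Per-request sequence lengths (context + scheduled tokens).
    seq_lens: torch.Tensor
```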
@@ -253,6 +254,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): |
253 | 254 | self.slot_mapping = torch.zeros(self.max_num_tokens, |
254 | 255 | dtype=torch.int32, |
255 | 256 | device=self.device) |
| 257 | + self.query_lens = torch.zeros(self.max_num_reqs, |
| 258 | + dtype=torch.int32, |
| 259 | + device=self.device) |
256 | 260 | # None in the first PP rank. The rest are set after load_model. |
257 | 261 | self.intermediate_tensors: Optional[IntermediateTensors] = None |
258 | 262 |
|
@@ -1528,6 +1532,7 @@ def _dummy_run( |
1528 | 1532 | skip_attn: bool = True, |
1529 | 1533 | with_prefill: bool = False, |
1530 | 1534 | is_torchair_compile: bool = False, |
| 1535 | + attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, |
1531 | 1536 | ) -> torch.Tensor: |
1532 | 1537 | if self.torchair_graph_enabled and not with_prefill: |
1533 | 1538 | num_tokens = self.select_torchair_padded_batch_size(num_tokens) |
@@ -1558,43 +1563,13 @@ def _dummy_run( |
1558 | 1563 | elif skip_attn: |
1559 | 1564 | attn_metadata = None |
1560 | 1565 | else: |
1561 | | - query_start_loc = self.query_start_loc[:num_reqs + 1] |
1562 | | - query_start_loc[:] = torch.arange( |
1563 | | - query_start_loc.numel(), |
1564 | | - device=query_start_loc.device, |
1565 | | - dtype=query_start_loc.dtype, |
1566 | | - ) |
1567 | | - seq_lens = self.seq_lens_np[:num_reqs] |
1568 | | - seq_lens[:] = seq_lens + 2 |
1569 | | - self.seq_lens_list = self.seq_lens_np.tolist()[:num_tokens] |
1570 | | - |
1571 | | - common_attn_metadata = CommonAttentionMetadata( |
1572 | | - query_start_loc=query_start_loc, seq_lens=seq_lens) |
1573 | | - |
1574 | | - self.query_lens = torch.from_numpy(num_scheduled_tokens) |
1575 | | - |
1576 | | - block_table = self.input_batch.block_table[0].block_table |
1577 | | - block_table[:num_reqs, 0] = torch.arange(1, |
1578 | | - num_reqs + 1, |
1579 | | - device=block_table.device, |
1580 | | - dtype=block_table.dtype) |
1581 | | - |
1582 | | - self.slot_mapping[:num_tokens] = torch.arange( |
1583 | | - 1, |
1584 | | - num_tokens + 1, |
1585 | | - device=self.slot_mapping.device, |
1586 | | - dtype=self.slot_mapping.dtype) * self.block_size + 1 |
1587 | | - |
1588 | | - attn_metadata = self.attn_metadata_builder.build( |
1589 | | - num_reqs=num_reqs, |
| 1566 | + attn_metadata = self.attn_metadata_builder.build_dummy_metadata( |
1590 | 1567 | num_actual_tokens=num_tokens, |
1591 | | - max_query_len=num_tokens, |
1592 | | - common_prefix_len=0, |
1593 | | - common_attn_metadata=common_attn_metadata, |
| 1568 | + num_reqs=num_reqs, |
| 1569 | + num_scheduled_tokens=num_scheduled_tokens, |
| 1570 | + attn_state=attn_state, |
1594 | 1571 | ) |
1595 | 1572 |
|
1596 | | - attn_metadata.attn_state = AscendAttentionState.DecodeOnly |
1597 | | - |
1598 | 1573 | with self.maybe_dummy_run_with_lora(self.lora_config, |
1599 | 1574 | num_scheduled_tokens): |
1600 | 1575 | model = self.model |
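For reference, the inlined setup deleted above is roughly what the new `build_dummy_metadata` call has to reproduce for the decode-only case. Below is a minimal sketch of that logic as a standalone helper, assuming the preallocated runner buffers shown earlier (`query_start_loc`, `seq_lens_np`, the first block table, `slot_mapping`); the function name and signature are illustrative, not the builder's actual API:

```python
import numpy as np
import torch


def fill_dummy_decode_buffers(query_start_loc: torch.Tensor,
                              seq_lens_np: np.ndarray,
                              block_table: torch.Tensor,
                              slot_mapping: torch.Tensor,
                              num_reqs: int,
                              num_tokens: int,
                              block_size: int) -> None:
    """Illustrative only: mirrors the removed inline dummy-run setup."""
    # One query token per request: query_start_loc = [0, 1, ..., num_reqs].
    loc = query_start_loc[:num_reqs + 1]
    loc[:] = torch.arange(loc.numel(), device=loc.device, dtype=loc.dtype)

    # Give every dummy request a small non-zero context length.
    seq_lens_np[:num_reqs] += 2

    # Point request i at dummy block i + 1 ...
    block_table[:num_reqs, 0] = torch.arange(1, num_reqs + 1,
                                             device=block_table.device,
                                             dtype=block_table.dtype)

    # ... and give every token a distinct dummy KV-cache slot.
    slot_mapping[:num_tokens] = torch.arange(
        1, num_tokens + 1,
        device=slot_mapping.device,
        dtype=slot_mapping.dtype) * block_size + 1
```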
@@ -1977,6 +1952,7 @@ def capture_model(self) -> None: |
1977 | 1952 | # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode |
1978 | 1953 | with graph_capture(device=self.device): |
1979 | 1954 | skip_attn = not self.vllm_config.compilation_config.full_cuda_graph |
| 1955 | + # TODO: Make sure to pass attn_state to _dummy_run in the future
1980 | 1956 | for num_tokens in reversed(self.aclgraph_batch_sizes): |
1981 | 1957 | for _ in range(self.vllm_config.compilation_config. |
1982 | 1958 | cudagraph_num_of_warmups): |
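Per the TODO above, graph capture still relies on the `DecodeOnly` default of `_dummy_run`. A hedged sketch of what an explicit call inside the `capture_model` loop might look like once the state is threaded through; the keyword argument is the new `_dummy_run` parameter, but treat the exact positional arguments as an assumption:

```python
# Fragment from inside capture_model (illustrative only): pass the intended
# attention state explicitly instead of relying on the DecodeOnly default.
for num_tokens in reversed(self.aclgraph_batch_sizes):
    self._dummy_run(num_tokens,
                    skip_attn=skip_attn,
                    attn_state=AscendAttentionState.DecodeOnly)
```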
|