diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index dd8c638394..292294b937 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -132,6 +132,8 @@ class AscendMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
+    with_prefill_across_dp: bool = False
+
 
 class AscendAttentionMetadataBuilder:
 
@@ -142,8 +144,12 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self, num_reqs, num_actual_tokens, max_query_len,
-              common_prefix_len):
+    def build(self,
+              num_reqs,
+              num_actual_tokens,
+              max_query_len,
+              common_prefix_len,
+              with_prefill_across_dp: bool = False):
 
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
@@ -160,15 +166,17 @@ def build(self, num_reqs, num_actual_tokens, max_query_len,
         query_start_loc = query_start_loc_cpu.to(self.runner.device,
                                                  non_blocking=True)
 
-        attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
-                                       block_tables=block_table,
-                                       query_start_loc=query_start_loc,
-                                       query_lens=query_lens,
-                                       seq_lens=seq_lens,
-                                       max_query_len=max_query_len,
-                                       slot_mapping=slot_mapping,
-                                       attn_mask=attn_mask,
-                                       attn_state=attn_state)
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=num_actual_tokens,
+            block_tables=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            max_query_len=max_query_len,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state,
+            with_prefill_across_dp=with_prefill_across_dp)
         return attn_metadata