diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index a990cb2f1a97..3dd817a0564f 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -359,6 +359,8 @@ class ChunkedContextMetadata(
 class MLACommonDecodeMetadata:
     block_table: torch.Tensor
     seq_lens: torch.Tensor
+    query_base_positions: Optional[torch.Tensor] = field(default=None,
+                                                         kw_only=True)
 
 
 D = TypeVar("D", bound=MLACommonDecodeMetadata)
@@ -615,9 +617,19 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
                       query_start_loc_cpu: torch.Tensor,
                       query_start_loc_device: torch.Tensor,
                       num_decode_tokens: int) -> MLACommonDecodeMetadata:
+
+        # Compute DCP query base positions if using DCP
+        query_base_positions = None
+
+        if self.dcp_world_size > 1:
+            query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+            query_base_positions = (seq_lens_cpu - query_lens).to(
+                seq_lens_device.device)
+
         return MLACommonDecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
+            query_base_positions=query_base_positions,
         )
 
     def build_for_cudagraph_capture(
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 472095e13615..c2f727ba77c5 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -11,7 +11,6 @@
 from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
                                            get_flash_attn_version)
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonDecodeMetadata,
@@ -99,11 +98,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         # pre-allocated during capture.
         self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
-        # TODO(lucas): Until we add support for the DCP custom masking we need
-        # to restrict decodes to q_len == 1 when DCP is enabled.
-        self.__class__.reorder_batch_threshold = 1 \
-            if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
-
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                          max_seq_len, causal):
         if self.fa_aot_schedule:
@@ -262,6 +256,9 @@ def _forward_decode(
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
             num_splits=attn_metadata.decode.max_num_splits,
+            dcp_rank=self.dcp_rank,
+            dcp_world_size=self.dcp_world_size,
+            query_base_positions=attn_metadata.decode.query_base_positions,
         )
 
         if self.need_to_return_lse_for_decode:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d4afaf51e6e8..c8c8f9d68b04 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -474,12 +474,6 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
             return
 
         if self.reorder_batch_threshold is not None:
-            # NOTE(lucas): currently no backend supports the custom masking
-            # required for DCP with q_len > 1, so we assert here. Remove this
-            # assert once the custom mask is support is added to FA3.
-            if self.dcp_world_size > 1:
-                assert self.reorder_batch_threshold == 1, \
-                    "DCP not support reorder_batch_threshold > 1 now."
             reorder_batch_to_split_decodes_and_prefills(
                 self.input_batch,
                 scheduler_output,