diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index e6686275cabb..d4908772c69e 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 4695e6bed5366c41e28c06cd86170166e4f43d00
+        GIT_TAG 8f468e7da54a8e2f98abfa7c38636aac91c0cba1
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 3fb00f5917ea..af396c2b4103 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -370,6 +370,7 @@ class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
 class MLACommonDecodeMetadata:
     block_table: torch.Tensor
     seq_lens: torch.Tensor
+    dcp_tot_seq_lens: Optional[torch.Tensor]
 
 
 D = TypeVar("D", bound=MLACommonDecodeMetadata)
@@ -682,10 +683,12 @@ def _build_decode(
         query_start_loc_cpu: torch.Tensor,
         query_start_loc_device: torch.Tensor,
         num_decode_tokens: int,
+        dcp_tot_seq_lens_device: Optional[torch.Tensor],
     ) -> MLACommonDecodeMetadata:
         return MLACommonDecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
+            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
         )
 
     def build_for_cudagraph_capture(
@@ -727,6 +730,7 @@ def build(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
 
         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
@@ -742,7 +746,10 @@ def build(
 
         # Note(hc): update seq_lens of decode reqs under DCP.
         if self.dcp_world_size > 1:
-            seq_lens[:num_decodes] = seq_lens[:num_decodes] // self.dcp_world_size + (
+            assert dcp_local_seq_lens is not None
+            dcp_local_seq_lens[:num_decodes] = seq_lens[
+                :num_decodes
+            ] // self.dcp_world_size + (
                 self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size
             )
@@ -899,10 +906,15 @@ def build(
             decode_metadata = self._build_decode(
                 block_table_tensor=block_table_tensor[:num_decodes, ...],
                 seq_lens_cpu=seq_lens_cpu[:num_decodes],
-                seq_lens_device=seq_lens[:num_decodes],
+                seq_lens_device=dcp_local_seq_lens[:num_decodes]
+                if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
+                else seq_lens[:num_decodes],
                 query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
                 query_start_loc_device=query_start_loc[: num_decodes + 1],
                 num_decode_tokens=num_decode_tokens,
+                dcp_tot_seq_lens_device=seq_lens[:num_decodes]
+                if self.dcp_world_size > 1
+                else None,
             )
 
         attn_metadata = self.metadata_cls(
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index c0c2dbe1f961..c043990ffcc6 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -17,7 +17,6 @@
     get_flash_attn_version,
 )
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -107,12 +106,6 @@ def __init__(
         # pre-allocated during capture.
         self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
-        # TODO(lucas): Until we add support for the DCP custom masking we need
-        # to restrict decodes to q_len == 1 when DCP is enabled.
-        self.reorder_batch_threshold = (
-            1 if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
-        )
-
     def _schedule_decode(
         self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
     ):
@@ -121,7 +114,7 @@ def _schedule_decode(
             batch_size=num_reqs,
             max_seqlen_q=max_query_len,
             max_seqlen_k=max_seq_len,
-            num_heads_q=self.num_heads,
+            num_heads_q=self.num_heads * self.dcp_world_size,
             num_heads_kv=1,
             headdim=self.mla_dims.qk_rope_head_dim,
             cache_seqlens=seqlens,
@@ -142,10 +135,11 @@ def _build_decode(
         query_start_loc_cpu: torch.Tensor,
         query_start_loc_device: torch.Tensor,
         num_decode_tokens: int,
+        dcp_tot_seq_lens_device: Optional[torch.Tensor],
     ) -> FlashAttnMLADecodeMetadata:
         query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         max_query_len = query_lens_cpu.max().item()
-        max_seq_len = seq_lens_cpu.max().item()
+        max_seq_len = seq_lens_device.max().item()
 
         scheduler_metadata = self._schedule_decode(
             num_reqs=seq_lens_cpu.numel(),
@@ -188,6 +182,7 @@ def _build_decode(
             max_seq_len=max_seq_len,
             scheduler_metadata=scheduler_metadata,
             max_num_splits=max_num_splits,
+            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
         )
@@ -289,6 +284,9 @@ def _forward_decode(
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
             num_splits=attn_metadata.decode.max_num_splits,
+            cp_world_size=self.dcp_world_size,
+            cp_rank=self.dcp_rank,
+            cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
         )
 
         if self.need_to_return_lse_for_decode:
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index f4f82f1cce91..e0f4a7f0382b 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -106,6 +106,7 @@ def _build_decode(
         query_start_loc_cpu: torch.Tensor,
         query_start_loc_device: torch.Tensor,
         num_decode_tokens: int,
+        dcp_tot_seq_lens_device: Optional[torch.Tensor],
     ) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = get_mla_metadata(
             seq_lens_device,
@@ -146,6 +147,7 @@ def _build_decode(
             seq_lens=seq_lens_device,
             tile_scheduler_metadata=tile_scheduler_metadata,
             num_splits=num_splits,
+            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
         )
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 54ebf071d96f..195b05e0a301 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -116,6 +116,7 @@ def _build_decode(
         query_start_loc_cpu: torch.Tensor,
         query_start_loc_device: torch.Tensor,
         num_decode_tokens: int,
+        dcp_tot_seq_lens_device: Optional[torch.Tensor],
     ) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
         block_table_bounds = (seq_lens_device + page_size - 1) // page_size
@@ -174,6 +175,7 @@ def _build_decode(
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr,
+            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
         )
 
         return attn_metadata
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 638063a8f6f8..7c6940d9b15d 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -93,6 +93,9 @@ class CommonAttentionMetadata:
     # Needed by CrossAttentionBuilder
     encoder_seq_lens: Optional[np.ndarray] = None
 
+    dcp_local_seq_lens: Optional[torch.Tensor] = None
+    """Sequence lengths of the local rank in decode context parallelism world"""
+
 
 def slice_query_start_locs(
     query_start_loc: torch.Tensor,
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index d597ce68ffe1..356ddfc9d986 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -597,6 +597,7 @@ def prepare_inputs_padded(
         block_table_tensor=common_attn_metadata.block_table_tensor,
         slot_mapping=common_attn_metadata.slot_mapping[token_indices],
         causal=True,
+        dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
     )
 
     token_indices_to_sample = (
@@ -868,6 +869,7 @@ def prepare_inputs(
         block_table_tensor=common_attn_metadata.block_table_tensor,
         slot_mapping=common_attn_metadata.slot_mapping[token_indices],
         causal=True,
+        dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
     )
 
     return spec_common_attn_metadata, token_indices
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ea3b18b447f3..35228a5d5284 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -398,6 +398,10 @@ def __init__(
             self.max_num_reqs + 1, dtype=torch.int32
         )
         self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
+        if self.dcp_world_size > 1:
+            self.dcp_local_seq_lens = self._make_buffer(
+                self.max_num_reqs, dtype=torch.int32
+            )
         # Because inputs_embeds may be bfloat16 and we don't need a numpy
         # version of this tensor, avoid a RuntimeError by not creating a
         # numpy buffer.
@@ -581,7 +585,10 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         # NOTE(lucas): currently no backend supports the custom masking
         # required for DCP with q_len > 1, so we assert here. Remove this
         # assert once the custom mask is support is added to FA3.
-        if self.dcp_world_size > 1:
+        if (
+            self.dcp_world_size > 1
+            and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
+        ):
             assert self.reorder_batch_threshold == 1, (
                 "DCP not support reorder_batch_threshold > 1 now."
             )
@@ -1335,6 +1342,9 @@ def _prepare_inputs(
                 num_logits_indices=logits_indices.size(0),
                 causal=True,
                 encoder_seq_lens=encoder_seq_lens,
+                dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
+                if self.dcp_world_size > 1
+                else None,
             )
 
             if self.speculative_config and spec_decode_common_attn_metadata is None:
@@ -3309,6 +3319,9 @@ def _dummy_run(
                         kv_cache_group_id
                     ].slot_mapping.gpu[:num_tokens],
                     causal=True,
+                    dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
+                    if self.dcp_world_size > 1
+                    else None,
                 )
                 for attn_group in self.attn_groups[kv_cache_group_id]:
                     if ubatch_slices is not None:
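
Illustration (not part of the patch): the split applied in `MLACommonMetadataBuilder.build()` above gives each DCP rank a per-request local KV length of `tot // dcp_world_size`, plus one extra token when `dcp_rank <= (tot - 1) % dcp_world_size`, which matches a round-robin placement of tokens across ranks; the unsharded lengths are carried separately as `dcp_tot_seq_lens` and forwarded to FA3 through the new `cp_tot_seqused_k` argument. A minimal plain-Python sketch of that formula, with a hypothetical helper name (`dcp_local_seq_len`) and scalar ints standing in for the torch tensors the patch operates on:

def dcp_local_seq_len(tot_seq_len: int, dcp_world_size: int, dcp_rank: int) -> int:
    # Each rank holds floor(tot / world) tokens, plus one more if its index falls
    # within the final partial round of the round-robin placement.
    return tot_seq_len // dcp_world_size + (
        dcp_rank <= (tot_seq_len - 1) % dcp_world_size
    )

# The per-rank lengths always sum back to the total KV length, e.g. 13 tokens
# on a DCP world of 4 are held as 4, 3, 3 and 3 tokens respectively.
assert [dcp_local_seq_len(13, 4, r) for r in range(4)] == [4, 3, 3, 3]
assert sum(dcp_local_seq_len(13, 4, r) for r in range(4)) == 13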