Merged
2 changes: 1 addition & 1 deletion cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 4695e6bed5366c41e28c06cd86170166e4f43d00
GIT_TAG 8f468e7da54a8e2f98abfa7c38636aac91c0cba1
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
16 changes: 14 additions & 2 deletions vllm/v1/attention/backends/mla/common.py
@@ -370,6 +370,7 @@ class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
class MLACommonDecodeMetadata:
block_table: torch.Tensor
seq_lens: torch.Tensor
dcp_tot_seq_lens: Optional[torch.Tensor]


D = TypeVar("D", bound=MLACommonDecodeMetadata)
@@ -682,10 +683,12 @@ def _build_decode(
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> MLACommonDecodeMetadata:
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)

def build_for_cudagraph_capture(
@@ -727,6 +730,7 @@ def build(
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens

query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

@@ -742,7 +746,10 @@

# Note(hc): update dcp_local_seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
seq_lens[:num_decodes] = seq_lens[:num_decodes] // self.dcp_world_size + (
assert dcp_local_seq_lens is not None
dcp_local_seq_lens[:num_decodes] = seq_lens[
:num_decodes
] // self.dcp_world_size + (
self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size
)
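
The per-rank arithmetic above can also be read as a standalone helper. The sketch below is a hypothetical illustration (the helper name is not a vLLM API); it evaluates the same floor-division plus modulo expression for each DCP rank.

```python
import torch

def dcp_local_seq_lens(seq_lens: torch.Tensor, dcp_rank: int, dcp_world_size: int) -> torch.Tensor:
    # Same arithmetic as the hunk above: every rank holds floor(L / world) KV
    # tokens, and ranks up to (L - 1) % world hold one extra token.
    return seq_lens // dcp_world_size + (
        dcp_rank <= (seq_lens - 1) % dcp_world_size
    ).to(seq_lens.dtype)

# Example: a 7-token decode request split across a 4-way DCP group.
lens = torch.tensor([7])
print([dcp_local_seq_lens(lens, r, 4).item() for r in range(4)])  # [2, 2, 2, 1]
```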

@@ -899,10 +906,15 @@ def build(
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
seq_lens_device=dcp_local_seq_lens[:num_decodes]
if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
else seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
query_start_loc_device=query_start_loc[: num_decodes + 1],
num_decode_tokens=num_decode_tokens,
dcp_tot_seq_lens_device=seq_lens[:num_decodes]
if self.dcp_world_size > 1
else None,
)

attn_metadata = self.metadata_cls(
16 changes: 7 additions & 9 deletions vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -17,7 +17,6 @@
get_flash_attn_version,
)
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
@@ -107,12 +106,6 @@ def __init__(
# pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

# TODO(lucas): Until we add support for the DCP custom masking we need
# to restrict decodes to q_len == 1 when DCP is enabled.
self.reorder_batch_threshold = (
1 if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
)

def _schedule_decode(
self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
):
@@ -121,7 +114,7 @@ def _schedule_decode(
batch_size=num_reqs,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
num_heads_q=self.num_heads,
num_heads_q=self.num_heads * self.dcp_world_size,
num_heads_kv=1,
headdim=self.mla_dims.qk_rope_head_dim,
cache_seqlens=seqlens,
@@ -142,10 +135,11 @@ def _build_decode(
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
max_seq_len = seq_lens_device.max().item()

scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(),
@@ -188,6 +182,7 @@ def _build_decode(
max_seq_len=max_seq_len,
scheduler_metadata=scheduler_metadata,
max_num_splits=max_num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)


@@ -289,6 +284,9 @@ def _forward_decode(
fa_version=3, # only version 3 is supported
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
num_splits=attn_metadata.decode.max_num_splits,
cp_world_size=self.dcp_world_size,
cp_rank=self.dcp_rank,
cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
)

if self.need_to_return_lse_for_decode:
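
With cp_world_size, cp_rank, and cp_tot_seqused_k passed through, each DCP rank attends only over its local KV shard, and the kernel returns the per-shard output together with its log-sum-exp. Those partials are then merged across the DCP group in the common MLA path using the standard log-sum-exp correction; the sketch below shows only the correction itself, with an illustrative helper name and shapes rather than vLLM's internal layout.

```python
import torch

def merge_dcp_partials(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    """Combine per-rank partial attention outputs via their log-sum-exp terms.

    outs: [world, tokens, heads, head_dim] partial outputs, one slice per rank
    lses: [world, tokens, heads] log-sum-exp of each rank's local scores
    """
    # Softmax of the LSEs over the rank axis is each shard's weight in the
    # globally normalized attention result.
    weights = torch.softmax(lses, dim=0)
    return (outs * weights.unsqueeze(-1)).sum(dim=0)

# Toy check: identical shards with equal LSEs merge back to a single shard.
outs = torch.randn(1, 4, 8, 16).expand(2, -1, -1, -1)
lses = torch.zeros(2, 4, 8)
assert torch.allclose(merge_dcp_partials(outs, lses), outs[0])
```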
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/mla/flashmla.py
@@ -106,6 +106,7 @@ def _build_decode(
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = get_mla_metadata(
seq_lens_device,
@@ -146,6 +147,7 @@ def _build_decode(
seq_lens=seq_lens_device,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)


2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -116,6 +116,7 @@ def _build_decode(
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
@@ -174,6 +175,7 @@ def _build_decode(
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)

return attn_metadata
3 changes: 3 additions & 0 deletions vllm/v1/attention/backends/utils.py
@@ -93,6 +93,9 @@ class CommonAttentionMetadata:
# Needed by CrossAttentionBuilder
encoder_seq_lens: Optional[np.ndarray] = None

dcp_local_seq_lens: Optional[torch.Tensor] = None
"""Per-request KV sequence lengths held by the local rank in the decode context parallel (DCP) group."""


def slice_query_start_locs(
query_start_loc: torch.Tensor,
2 changes: 2 additions & 0 deletions vllm/v1/spec_decode/eagle.py
@@ -597,6 +597,7 @@ def prepare_inputs_padded(
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
causal=True,
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
)

token_indices_to_sample = (
@@ -868,6 +869,7 @@ def prepare_inputs(
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
causal=True,
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
)

return spec_common_attn_metadata, token_indices
15 changes: 14 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -398,6 +398,10 @@ def __init__(
self.max_num_reqs + 1, dtype=torch.int32
)
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
if self.dcp_world_size > 1:
self.dcp_local_seq_lens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
# Because inputs_embeds may be bfloat16 and we don't need a numpy
# version of this tensor, avoid a RuntimeError by not creating a
# numpy buffer.
@@ -581,7 +585,10 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
# NOTE(lucas): currently no backend supports the custom masking
# required for DCP with q_len > 1, so we assert here. Remove this
# assert once custom mask support is added to FA3.
if self.dcp_world_size > 1:
if (
self.dcp_world_size > 1
and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
):
assert self.reorder_batch_threshold == 1, (
"DCP does not support reorder_batch_threshold > 1 yet."
)

Contributor: Since only flash_attn_mla supports the custom mask, we can't just remove this assert right now?

Contributor Author: Makes sense. I'll make a whitelist here for FA3 MLA.
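
A whitelist along the lines the author describes could look roughly like the sketch below; the set contents, helper name, and message are assumptions for illustration, not the merged implementation.

```python
# Hypothetical sketch of the backend whitelist mentioned in the reply above:
# only backends that implement the DCP custom mask may keep decode queries
# longer than one token when DCP is enabled.
DCP_CUSTOM_MASK_BACKENDS = {"FLASH_ATTN_MLA"}

def check_dcp_reorder_threshold(
    backend_name: str, dcp_world_size: int, reorder_batch_threshold: int
) -> None:
    if dcp_world_size > 1 and backend_name not in DCP_CUSTOM_MASK_BACKENDS:
        assert reorder_batch_threshold == 1, (
            f"DCP with the {backend_name} backend does not support "
            "reorder_batch_threshold > 1 yet."
        )
```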
@@ -1335,6 +1342,9 @@ def _prepare_inputs(
num_logits_indices=logits_indices.size(0),
causal=True,
encoder_seq_lens=encoder_seq_lens,
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
if self.dcp_world_size > 1
else None,
)

if self.speculative_config and spec_decode_common_attn_metadata is None:
@@ -3309,6 +3319,9 @@ def _dummy_run(
kv_cache_group_id
].slot_mapping.gpu[:num_tokens],
causal=True,
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
if self.dcp_world_size > 1
else None,
)
for attn_group in self.attn_groups[kv_cache_group_id]:
if ubatch_slices is not None: