
Commit 9f8290e

Merge pull request #3 from pisceskkk/long_seq_dev
[Feature] support multi-requests
2 parents 3f73536 + 51cb1f2 commit 9f8290e

9 files changed, +258 -237 lines


vllm/attention/backends/abstract.py

Lines changed: 3 additions & 0 deletions
@@ -264,6 +264,9 @@ class AttentionImpl(ABC, Generic[T]):
     dcp_world_size: int
     dcp_rank: int
 
+    cp_world_size: int
+    cp_rank: int
+
     def __new__(cls, *args, **kwargs):
         # use __new__ so that all subclasses will call this
         self = super().__new__(cls)

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 4 additions & 2 deletions
@@ -7,7 +7,8 @@
 
 import vllm.envs as envs
 from vllm.config import ParallelConfig
-from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank, get_context_model_parallel_rank
+from vllm.distributed import (get_context_model_parallel_rank, get_dp_group,
+                              get_tensor_model_parallel_rank)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
@@ -604,7 +605,8 @@ def make(tp_size_: int, dp_size_: int, cp_size_: int,
         level's of parallelism to use in the fused moe layer.
 
         Args:
-            tp_size_ (int): `tp_size` pa use_ep = (dp_size_ * tp_size_ssed into the FusedMoE constructor.
+            tp_size_ (int): `tp_size` pa use_ep = (dp_size_ * tp_size_ssed into
+                the FusedMoE constructor.
             dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
             vllm_parallel_config (ParallelConfig): vLLM's parallel config
                 object which contains the `enable_expert_parallel` flag.

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 2 deletions
@@ -13,9 +13,9 @@
 import vllm.envs as envs
 from vllm.config import get_current_vllm_config
 from vllm.config.parallel import ExpertPlacementStrategy
-from vllm.distributed import (get_dp_group, get_ep_group,
+from vllm.distributed import (get_context_model_parallel_world_size,
+                              get_dp_group, get_ep_group,
                               get_tensor_model_parallel_world_size,
-                              get_context_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.forward_context import ForwardContext, get_forward_context

vllm/platforms/cuda.py

Lines changed: 8 additions & 0 deletions
@@ -192,6 +192,14 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                     "compatible. Set the all_to_all backend to deepep_low_latency "
                     "to use those kernels instead.")
                 compilation_config.cudagraph_mode = CUDAGraphMode.NONE
+
+            if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                    and parallel_config.context_parallel_size > 1):
+                logger.info(
+                    "Context Parallel: disabling cudagraphs since CP."
+                )
+                compilation_config.cudagraph_mode = CUDAGraphMode.NONE
+
 
     @classmethod
     def get_current_memory_usage(cls,

vllm/v1/attention/backends/flashinfer.py

Lines changed: 63 additions & 84 deletions
@@ -21,8 +21,8 @@
                               AttentionType)
 from vllm.attention.ops.common import cp_lse_ag_out_ar
 from vllm.config import CUDAGraphMode, VllmConfig
-from vllm.logger import init_logger
 from vllm.distributed.parallel_state import get_cp_group
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.platforms import current_platform
@@ -239,7 +239,7 @@ class FlashInferMetadata:
     paged_kv_indptr_gpu: Optional[torch.Tensor] = None
 
     # For context parallel
-    cp_kv_recover_idx: Optional[torch.Tensor] = None
+    cp_allgather_restore_idx: Optional[torch.Tensor] = None
 
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@@ -262,9 +262,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                              self.kv_cache_spec.block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
-        # NOTE(qcs): Context Parallel do not support graph mode now
         self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
-            decode_mode() == CUDAGraphMode.FULL and self.cp_world_size == 1)
+            decode_mode() == CUDAGraphMode.FULL)
         if self.enable_cuda_graph:
             # For full cudagraph capture, one `decode_wrapper` for each batch
             # size is needed for FlashInfer.
@@ -552,7 +551,7 @@ def build(self,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
-            cp_kv_recover_idx=common_attn_metadata.cp_kv_recover_idx,
+            cp_allgather_restore_idx=common_attn_metadata.cp_allgather_restore_idx,
         )
 
         qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
@@ -599,38 +598,30 @@
         qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
             prefill_start]
         paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
-        prefill_num_computed_tokens_cpu = num_computed_tokens_cpu[prefill_start:]
+        prefill_num_computed_tokens_cpu = \
+            num_computed_tokens_cpu[prefill_start:]
         if not attn_metadata.prefill_use_trtllm:
             if self.cp_world_size > 1:
-                # NOTE(qcs): no chunked prefill and prefix caching
+                assert common_attn_metadata.query_positions is not None
                 kv_indptr_cpu = qo_indptr_cpu * self.cp_world_size
                 # init custom mask for head-tail query order
-                mask_arr = []
-                q_pos = common_attn_metadata.query_positions
-                for i in range(num_prefills):
-                    # |----<C>-----|-<Q0>-|-<Q1>-|
-                    # |---<C+Q*cp_world_size>----|
-                    # cp_world_size = 2
-                    # Q = 2
-                    # C = 8
-                    # cur_q_pos = [0,3]
-                    # context_mask_i.shape = (2, 8)
-                    # upper = [0,1,2,3]
-                    # local_mask_i = [[True, False, False, False],
-                    #                 [True, True, True, True]] # size=(2, 4)
-                    # mask_i.shape = (2, 12)
-                    cur_q_pos = torch.from_numpy(q_pos[qo_indptr_cpu[i]:qo_indptr_cpu[i+1]])
-                    Q = len(cur_q_pos)
-                    C = prefill_num_computed_tokens_cpu[i]
-                    if Q <= 0:
-                        mask_arr.append(torch.zeros(0, dtype=torch.bool))
-                        continue
-                    context_mask_i = torch.ones((Q, C), dtype=torch.bool)
-                    upper = torch.arange(Q*self.cp_world_size)
-                    local_mask_i = (upper.unsqueeze(0) <= cur_q_pos.unsqueeze(1))
-                    mask_i = torch.cat([context_mask_i, local_mask_i], dim=1)
-                    mask_arr.append(mask_i.flatten())
-                custom_mask = torch.cat(mask_arr, dim=0).to(self.device)
+                q_pos = torch.from_numpy(
+                    common_attn_metadata.query_positions[
+                        prefill_start:]).long()
+                kv_lens = prefill_num_computed_tokens_cpu + \
+                    kv_indptr_cpu[1:] - kv_indptr_cpu[:-1]
+                max_q_lens = int(q_pos.max().item()) + 1
+                max_kv_lens = int(kv_lens.max().item())
+                mask = torch.ones(max_q_lens, max_kv_lens,
+                                  dtype=torch.bool).tril()
+                selected_rows = torch.index_select(mask, 0, q_pos)
+                col_indices = torch.arange(max_kv_lens).expand(q_pos.size(0), -1)
+                valid_mask = col_indices < torch.repeat_interleave(
+                    kv_lens,
+                    qo_indptr_cpu[1:] - \
+                    qo_indptr_cpu[:-1]
+                ).unsqueeze(1)
+                custom_mask = selected_rows[valid_mask].to(self.device)
 
                 attn_metadata.prefill_wrapper.plan(
                     qo_indptr_cpu.to(self.device),
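For readers following the vectorized branch above, here is a standalone sketch (illustrative values only, not code from this PR) of how selecting rows of one shared lower-triangular matrix by query position, then trimming the columns beyond each request's kv length, yields the flattened boolean mask handed to the prefill wrapper:

import torch

# Two prefill requests, cp_world_size = 2, head-tail query order.
q_pos = torch.tensor([0, 3, 1, 2])     # per-token query positions
qo_indptr = torch.tensor([0, 2, 4])    # 2 query tokens per request
kv_lens = torch.tensor([4, 3])         # kv tokens visible to each request

max_q = int(q_pos.max().item()) + 1
max_kv = int(kv_lens.max().item())

# Row i of `tril` is the causal row for a query sitting at position i.
tril = torch.ones(max_q, max_kv, dtype=torch.bool).tril()
rows = torch.index_select(tril, 0, q_pos)

# Keep only the columns that exist for each token's request, then flatten.
cols = torch.arange(max_kv).expand(q_pos.size(0), -1)
per_token_kv = torch.repeat_interleave(kv_lens, qo_indptr[1:] - qo_indptr[:-1])
custom_mask = rows[cols < per_token_kv.unsqueeze(1)]

print(custom_mask.shape)   # one flattened causal row per (query token, kv token) pair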
@@ -874,6 +865,28 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        num_prefill_tokens = attn_metadata.num_prefill_tokens
+
+        key_across_cp = get_cp_group().all_gather(
+            key.contiguous(), dim=0)
+        value_across_cp = get_cp_group().all_gather(
+            value.contiguous(), dim=0)
+        if (self.cp_world_size > 1
+                and attn_metadata.cp_allgather_restore_idx is not None):
+            # Reorder kv after cp allgather.
+            # Note that there are duplicate decoding tokens,
+            # but we only save the first one in kvcache.
+            key_across_cp = torch.index_select(
+                key_across_cp, 0,
+                attn_metadata.cp_allgather_restore_idx
+            )
+            value_across_cp = torch.index_select(
+                value_across_cp, 0,
+                attn_metadata.cp_allgather_restore_idx
+            )
+        key = key_across_cp
+        value = value_across_cp
 
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
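The reordering above is easier to see with concrete numbers. A minimal sketch, using plain tensors instead of a real CP group and made-up token ids, of what the all-gather plus cp_allgather_restore_idx accomplishes:

import torch

cp_world_size, num_decode_tokens = 2, 2

# Each rank holds [its decode tokens | its prefill shard]; the ids are
# token labels standing in for key/value rows.
rank0 = torch.tensor([100, 101, 0, 3])   # decodes d0, d1 + prefill shard {0, 3}
rank1 = torch.tensor([100, 101, 1, 2])   # decodes d0, d1 + prefill shard {1, 2}

# all_gather(dim=0) simply stacks the rank shards.
gathered = torch.cat([rank0, rank1])

# The restore index rebuilds [duplicated decodes | prefills in order].
restore_idx = torch.tensor([0, 1, 4, 5, 2, 6, 7, 3])
restored = torch.index_select(gathered, 0, restore_idx)
# tensor([100, 101, 100, 101, 0, 1, 2, 3])

# Decode tokens now appear cp_world_size times, which is why the prefill
# slice later in this file starts at num_decode_tokens * cp_world_size.
prefill_rows = restored[num_decode_tokens * cp_world_size:]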
@@ -883,17 +896,16 @@ def forward(
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
-            if self.cp_world_size == 1:
-                torch.ops._C_cache_ops.reshape_and_cache_flash(
-                    key,
-                    value,
-                    kv_cache[:, 0],
-                    kv_cache[:, 1],
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                kv_cache[:, 0],
+                kv_cache[:, 1],
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
         # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
         # to process the cache when the kv_cache_dtype is fp8
@@ -913,9 +925,6 @@ def forward(
             output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
             return output
 
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-
         stride_order = FlashInferBackend.get_kv_cache_stride_order()
         kv_cache_permute = kv_cache.permute(*stride_order)
         # Regular attention (common case).
# Regular attention (common case).
@@ -933,34 +942,15 @@ def forward(
933942
self.logits_soft_cap or 0.0)
934943
assert prefill_wrapper._sm_scale == self.scale
935944
if self.cp_world_size > 1:
936-
key_across_cp = get_cp_group().all_gather(
937-
key[num_decode_tokens:].contiguous(), dim=0)
938-
value_across_cp = get_cp_group().all_gather(
939-
value[num_decode_tokens:].contiguous(), dim=0)
940-
key_across_cp = torch.index_select(
941-
key_across_cp, 0,
942-
attn_metadata.cp_kv_recover_idx
943-
)
944-
value_across_cp = torch.index_select(
945-
value_across_cp, 0,
946-
attn_metadata.cp_kv_recover_idx
947-
)
948-
torch.ops._C_cache_ops.reshape_and_cache_flash(
949-
key_across_cp,
950-
value_across_cp,
951-
kv_cache[:, 0],
952-
kv_cache[:, 1],
953-
attn_metadata.slot_mapping[num_decode_tokens:],
954-
self.kv_cache_dtype,
955-
layer._k_scale,
956-
layer._v_scale,
957-
)
958-
# TODO(qcs): 考虑 chunked prefill/ prefix cache 情况下
959-
# kvcache的获取与拼接
945+
# NOTE(qcs): Allgather causes duplicate decoding tokens.
946+
prefill_key = key[
947+
num_decode_tokens*self.cp_world_size:]
948+
prefill_value = value[
949+
num_decode_tokens*self.cp_world_size:]
960950
prefill_wrapper.run(
961951
prefill_query,
962-
key_across_cp,
963-
value_across_cp,
952+
prefill_key,
953+
prefill_value,
964954
out=output[num_decode_tokens:],
965955
)
966956
else:
@@ -1047,17 +1037,6 @@ def forward(
                                   or 0.0)
             assert decode_wrapper._sm_scale == self.scale
             if self.cp_world_size > 1:
-                torch.ops._C_cache_ops.reshape_and_cache_flash(
-                    key[:num_decode_tokens],
-                    value[:num_decode_tokens],
-                    kv_cache[:, 0],
-                    kv_cache[:, 1],
-                    attn_metadata.slot_mapping[:num_decode_tokens],
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-                kv_cache_permute = kv_cache.permute(*stride_order)
                 out, lse = decode_wrapper.run(
                     decode_query,
                     kv_cache_permute,

vllm/v1/attention/backends/utils.py

Lines changed: 7 additions & 4 deletions
@@ -83,7 +83,7 @@ class CommonAttentionMetadata:
 
     # Needed by custom mask calc for context parallelism
     query_positions: Optional[np.ndarray] = None
-    cp_kv_recover_idx: Optional[torch.Tensor] = None
+    cp_allgather_restore_idx: Optional[torch.Tensor] = None
 
 def slice_query_start_locs(
     query_start_loc: torch.Tensor,
@@ -139,10 +139,13 @@ def _make_metadata_with_slice(
     block_table_tensor = attn_metadata.block_table_tensor[request_slice]
     slot_mapping = attn_metadata.slot_mapping[token_slice]
 
+    # TODO(qcs): check if we can split query_positions and
+    # cp_kv_recover_idx as following approach
     query_positions = attn_metadata.query_positions[token_slice] \
         if attn_metadata.query_positions is not None else None
-    cp_kv_recover_idx = attn_metadata.cp_kv_recover_idx[token_slice] \
-        if attn_metadata.cp_kv_recover_idx is not None else None
+    cp_allgather_restore_idx = attn_metadata.cp_allgather_restore_idx[
+        token_slice] if attn_metadata.cp_allgather_restore_idx is not None \
+        else None
 
     return CommonAttentionMetadata(
         query_start_loc=query_start_loc,
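As background for the slicing above, splitting a batch into per-ubatch slices also requires rebasing cumulative offsets such as query_start_loc. A small illustration with made-up values (not code from this PR):

import torch

query_start_loc = torch.tensor([0, 2, 5, 9])   # 3 requests with 2, 3 and 4 tokens
request_slice = slice(1, 3)                    # keep requests 1 and 2

# Token range covered by the sliced requests.
token_slice = slice(int(query_start_loc[request_slice.start]),
                    int(query_start_loc[request_slice.stop]))   # tokens 2..8

# Rebase the cumulative offsets so they start at zero for this slice.
sliced = query_start_loc[request_slice.start:request_slice.stop + 1]
rebased = sliced - sliced[0]                   # tensor([0, 3, 7])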
@@ -157,7 +160,7 @@ def _make_metadata_with_slice(
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
         query_positions=query_positions,
-        cp_kv_recover_idx=cp_kv_recover_idx,
+        cp_allgather_restore_idx=cp_allgather_restore_idx,
     )
 
 
vllm/v1/executor/multiproc_executor.py

Lines changed: 7 additions & 5 deletions
@@ -28,9 +28,9 @@
                                              destroy_model_parallel)
 from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                   MessageQueue)
-from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
-                                             get_pp_group, get_tp_group,
-                                             get_cp_group)
+from vllm.distributed.parallel_state import (get_cp_group, get_dp_group,
+                                             get_ep_group, get_pp_group,
+                                             get_tp_group)
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import worker_receiver_cache_from_config
@@ -64,7 +64,8 @@ def _init_executor(self) -> None:
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
         pp_parallel_size = self.parallel_config.pipeline_parallel_size
         context_parallel_size = self.parallel_config.context_parallel_size
-        assert self.world_size == tensor_parallel_size * pp_parallel_size * context_parallel_size, (
+        assert self.world_size == tensor_parallel_size * pp_parallel_size * \
+            context_parallel_size, (
             f"world_size ({self.world_size}) must be equal to the "
             f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
            f"_parallel_size ({pp_parallel_size}) x context"
@@ -345,7 +346,8 @@ def _get_output_rank(self) -> int:
         # 16-23, PP rank 2
         # 24-31, PP rank 3
         # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
-        return self.world_size - self.parallel_config.tensor_parallel_size * self.parallel_config.context_parallel_size
+        return self.world_size - self.parallel_config.tensor_parallel_size * \
+            self.parallel_config.context_parallel_size
 
 
 @dataclass
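A quick numeric check of the updated formula (sizes are illustrative, not from this PR): with TP x CP ranks grouped within each pipeline stage, the output is produced by the first rank of the last PP stage.

tensor_parallel_size = 2
context_parallel_size = 2
pipeline_parallel_size = 2
world_size = (tensor_parallel_size * context_parallel_size
              * pipeline_parallel_size)                        # 8 workers

# Same arithmetic as the updated _get_output_rank.
output_rank = world_size - tensor_parallel_size * context_parallel_size
assert output_rank == 4   # ranks 4-7 form the last PP stage; 4 is its first rank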

vllm/v1/worker/block_table.py

Lines changed: 15 additions & 8 deletions
@@ -5,7 +5,7 @@
 import numpy as np
 import torch
 
-from vllm.distributed import get_dcp_group, get_cp_group
+from vllm.distributed import get_cp_group, get_dcp_group
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.v1.utils import CpuGpuBuffer
@@ -92,18 +92,21 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
 
         # Use a "virtual block" which equals to world_size * block_size
         # for block_table_indices calculation.
-        virtual_block_size = self.block_size * self.dcp_world_size * self.cp_world_size
+        virtual_block_size = self.block_size * self.dcp_world_size * \
+            self.cp_world_size
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions // virtual_block_size)
         block_numbers = self.block_table.np.ravel()[block_table_indices]
         # Use virtual_block_size for mask calculation, which marks local
         # tokens.
         virtual_block_offsets = positions % virtual_block_size
-        self.current_rank = self.dcp_world_size * self.cp_rank + self.dcp_rank
-        mask = (virtual_block_offsets %
-                (self.dcp_world_size * self.cp_world_size) == self.current_rank)
+        self.current_rank = self.dcp_world_size * self.cp_rank + \
+            self.dcp_rank
+        mask = (virtual_block_offsets % (self.dcp_world_size * \
+            self.cp_world_size) == self.current_rank)
         # Calculate local block_offsets
-        block_offsets = virtual_block_offsets // (self.dcp_world_size * self.cp_world_size)
+        block_offsets = virtual_block_offsets // \
+            (self.dcp_world_size * self.cp_world_size)
         # Calculate slot_mapping
         slot_mapping = block_numbers * self.block_size + block_offsets
         # Write final slots, use -1 for not-local
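To make the virtual-block arithmetic above concrete, here is a standalone NumPy sketch with made-up sizes (block_size=4, dcp_world_size=2, cp_world_size=2): only every (dcp * cp)-th token is local to a given rank, and everything else is marked -1.

import numpy as np

block_size, dcp_world_size, cp_world_size = 4, 2, 2
dcp_rank, cp_rank = 1, 0
current_rank = dcp_world_size * cp_rank + dcp_rank            # 1

positions = np.arange(16)                 # token positions of one request
block_numbers = np.array([7])             # this rank's block ids for the request

virtual_block_size = block_size * dcp_world_size * cp_world_size   # 16
virtual_offsets = positions % virtual_block_size

mask = virtual_offsets % (dcp_world_size * cp_world_size) == current_rank
block_offsets = virtual_offsets // (dcp_world_size * cp_world_size)
slot_mapping = block_numbers[positions // virtual_block_size] * block_size \
    + block_offsets
slot_mapping = np.where(mask, slot_mapping, -1)

print(slot_mapping)   # only positions 1, 5, 9, 13 map to local slots 28..31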
@@ -147,8 +150,12 @@ def _make_buffer(self, *size: Union[int, torch.SymInt],
                             device=self.device,
                             pin_memory=self.pin_memory)
 
-    def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[list[list[int]]]:
-        "Splits computed token counts across dcp and sp dimensions for distributed allocation."
+    def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) \
+            -> list[list[list[int]]]:
+        """
+        Splits computed token counts across dcp and sp dimensions for
+        distributed allocation.
+        """
         num_requests = len(num_computed_tokens)
         num_computed_tokens_of_dcp_sp = [[
             [0] * self.dcp_world_size for _ in range(self.cp_world_size)
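The body of get_split_computed_tokens is truncated in this diff. As a hedged sketch of the kind of result it returns (the exact remainder policy here is an assumption, not the PR's code), splitting each request's computed-token count across cp x dcp ranks might look like:

import numpy as np

def split_computed_tokens(num_computed_tokens: np.ndarray,
                          cp_world_size: int,
                          dcp_world_size: int) -> list[list[list[int]]]:
    total_ranks = cp_world_size * dcp_world_size
    out = []
    for n in num_computed_tokens.tolist():
        base, rem = divmod(int(n), total_ranks)
        flat = [base + (1 if r < rem else 0) for r in range(total_ranks)]
        # Regroup the flat per-rank counts as [request][cp][dcp].
        out.append([flat[c * dcp_world_size:(c + 1) * dcp_world_size]
                    for c in range(cp_world_size)])
    return out

print(split_computed_tokens(np.array([10, 7]),
                            cp_world_size=2, dcp_world_size=2))
# [[[3, 3], [2, 2]], [[2, 2], [2, 1]]]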
