
Commit 1598b45

Perf: support PIECEWISE cuda graph for PCP
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
1 parent e2e2952 commit 1598b45

File tree

4 files changed: +28 −20 lines changed

vllm/config/vllm.py

Lines changed: 9 additions & 0 deletions
@@ -359,6 +359,15 @@ def __post_init__(self):
         ):
             self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
 
+        # Prefill context parallel does not support full CUDA graphs yet.
+        if self.parallel_config.prefill_context_parallel_size > 1:
+            logger.warning(
+                "Prefill context parallel (PCP) is enabled, which is "
+                "incompatible with full CUDA graphs. Setting "
+                "cudagraph_mode to PIECEWISE."
+            )
+            self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+
         # decode context parallel do not support full cudagraphs now.
         if self.parallel_config.decode_context_parallel_size > 1:
             logger.warning(
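
Net effect: enabling PCP now downgrades full CUDA graph capture to piecewise capture instead of disabling CUDA graphs outright (the old opt-out in vllm/platforms/cuda.py is deleted below). A minimal, self-contained sketch of this fallback rule, with toy stand-ins rather than vLLM's real VllmConfig and CUDAGraphMode:

```python
# Toy sketch of the fallback rule above; ToyConfig and the enum values are
# illustrative stand-ins, not vLLM's real VllmConfig / CUDAGraphMode.
from dataclasses import dataclass
from enum import Enum, auto


class CUDAGraphMode(Enum):
    NONE = auto()
    PIECEWISE = auto()
    FULL = auto()


@dataclass
class ToyConfig:
    prefill_context_parallel_size: int = 1
    cudagraph_mode: CUDAGraphMode = CUDAGraphMode.FULL

    def __post_init__(self) -> None:
        # PCP cannot run under full CUDA graphs, so fall back to piecewise
        # capture instead of turning CUDA graphs off entirely (the old behavior).
        if self.prefill_context_parallel_size > 1:
            self.cudagraph_mode = CUDAGraphMode.PIECEWISE


cfg = ToyConfig(prefill_context_parallel_size=2)
assert cfg.cudagraph_mode is CUDAGraphMode.PIECEWISE
```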

vllm/platforms/cuda.py

Lines changed: 0 additions & 7 deletions
@@ -206,13 +206,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             )
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
-        if (
-            compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-            and parallel_config.prefill_context_parallel_size > 1
-        ):
-            logger.info("Prefill Context Parallel: disabling cudagraphs since PCP.")
-            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
-
     @classmethod
     def get_current_memory_usage(
         cls, device: torch.types.Device | None = None

vllm/v1/attention/backends/flashinfer.py

Lines changed: 13 additions & 12 deletions
@@ -1014,24 +1014,25 @@ def forward(
 
         num_actual_tokens = attn_metadata.num_actual_tokens
 
-        key_across_cp = get_pcp_group().all_gather(key.contiguous(), dim=0)
-        value_across_cp = get_pcp_group().all_gather(value.contiguous(), dim=0)
-        if (
-            self.pcp_world_size > 1
-            and attn_metadata.pcp_allgather_restore_idx is not None
-        ):
-            # Reorder kv after cp allgather.
+        if self.pcp_world_size > 1:
+            assert attn_metadata.pcp_allgather_restore_idx is not None
+            # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx
+            # ignores the padding from CUDA Graph. To be optimized for performance!
+            key_across_cp = get_pcp_group().all_gather(
+                key[:num_actual_tokens].contiguous(), dim=0
+            )
+            value_across_cp = get_pcp_group().all_gather(
+                value[:num_actual_tokens].contiguous(), dim=0
+            )
+            # Reorder kv after pcp allgather.
             # Note that there are duplicate decoding tokens,
             # but we only save the first one in kvcache.
-            key_across_cp = torch.index_select(
+            key = torch.index_select(
                 key_across_cp, 0, attn_metadata.pcp_allgather_restore_idx
             )
-            value_across_cp = torch.index_select(
+            value = torch.index_select(
                 value_across_cp, 0, attn_metadata.pcp_allgather_restore_idx
             )
-            key = key_across_cp
-            value = value_across_cp
-
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
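
The key change: under piecewise CUDA graphs, key and value are padded to the captured batch size, but pcp_allgather_restore_idx is built over actual tokens only, so the padding must be sliced off before the all-gather. A runnable single-process sketch of that slice, gather, reorder pattern (torch.cat stands in for the collective, and the interleaved restore index is hypothetical, purely to illustrate the reordering step):

```python
import torch

num_actual_tokens = 3      # real tokens on each PCP rank this step
# Per-rank key tensors, zero-padded up to the CUDA-graph batch size of 4.
key_rank0 = torch.tensor([10.0, 11.0, 12.0, 0.0])
key_rank1 = torch.tensor([20.0, 21.0, 22.0, 0.0])

# Slice off the graph padding *before* the gather, as the diff does.
key_across_cp = torch.cat(
    [key_rank0[:num_actual_tokens], key_rank1[:num_actual_tokens]]
)

# Restore index built over actual tokens only (here: interleave the shards).
pcp_allgather_restore_idx = torch.tensor([0, 3, 1, 4, 2, 5])
key = torch.index_select(key_across_cp, 0, pcp_allgather_restore_idx)
print(key)  # tensor([10., 20., 11., 21., 12., 22.])

# Without the slice, the gathered layout would be
# [10, 11, 12, 0, 20, 21, 22, 0]: every rank-1 entry in the restore index
# would be off by the pad length and select the wrong (or a padding) token.
```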

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 1 deletion
@@ -2696,7 +2696,12 @@ def execute_model(
             aux_hidden_states = None
 
         if self.pcp_world_size > 1:
-            hidden_states = get_pcp_group().all_gather(hidden_states, 0)
+            # NOTE: we must `slice` hidden_states because pcp_allgather_restore_idx
+            # ignores the padding from CUDA Graph.
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states[:num_scheduled_tokens],
+                0,
+            )
             hidden_states = torch.index_select(
                 hidden_states,
                 0,
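
The same slice-before-gather rule applies to the model's output hidden states in the runner; a brief shape-level sketch under assumed sizes (the gather is again simulated):

```python
import torch

num_scheduled_tokens = 5
padded_tokens, hidden_size = 8, 16   # graph-captured batch size, hidden dim
hidden_states = torch.randn(padded_tokens, hidden_size)

sliced = hidden_states[:num_scheduled_tokens]
gathered = torch.cat([sliced, sliced])   # stands in for a 2-rank PCP all_gather
assert gathered.shape == (2 * num_scheduled_tokens, hidden_size)
# torch.index_select with pcp_allgather_restore_idx (built for
# 2 * num_scheduled_tokens rows, no padding) then restores token order.
```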
