Skip to content

Commit d09cda7

Browse files
FENPpisceskkkLookAround0301
committed
Init support PCP with FlashInfer.
Co-authored-by: FENP <yuanyongjie.yyj@antgroup.com> Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com> Co-authored-by: LookAround <lixushi@huawei.com> Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com> Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com> Signed-off-by: LookAround <lixushi@huawei.com>
1 parent 17c540a commit d09cda7

File tree

21 files changed

+655
-109
lines changed

21 files changed

+655
-109
lines changed

tests/distributed/test_context_parallel.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@ class ParallelSetup(NamedTuple):
3030
tp_size: int
3131
pp_size: int
3232
dcp_size: int
33+
pcp_size: int
3334
eager_mode: bool
3435
chunked_prefill: bool
3536

3637

3738
class CPTestOptions(NamedTuple):
3839
multi_node_only: bool
3940
load_format: str | None = None
41+
attn_backend: str = "FLASH_ATTN"
4042

4143

4244
@dataclass
@@ -52,20 +54,25 @@ def detailed(
5254
tp_base: int = 4,
5355
pp_base: int = 1,
5456
dcp_base: int = 1,
57+
pcp_base: int = 1,
5558
multi_node_only: bool = False,
5659
runner: RunnerOption = "auto",
5760
load_format: str | None = None,
61+
attn_backend: str = "FLASH_ATTN",
5862
):
5963
parallel_setups = []
6064
for eager_mode_val in [False]:
6165
for pp_multiplier in [1]:
62-
for dcp_multiplier in [0.5, 1]:
66+
# TODO(qcs): Test the effect of mixed activation
67+
# when PCP and DCP are compatible.
68+
for pcp_multiplier, dcp_multiplier in zip([1, 2, 1], [0.5, 1, 1]):
6369
for chunked_prefill_val in [True]:
6470
parallel_setups.append(
6571
ParallelSetup(
6672
tp_size=tp_base,
6773
pp_size=pp_multiplier * pp_base,
6874
dcp_size=int(dcp_multiplier * tp_base),
75+
pcp_size=int(pcp_multiplier * pcp_base),
6976
eager_mode=eager_mode_val,
7077
chunked_prefill=chunked_prefill_val,
7178
)
@@ -75,7 +82,9 @@ def detailed(
7582
distributed_backends=["mp"],
7683
runner=runner,
7784
test_options=CPTestOptions(
78-
multi_node_only=multi_node_only, load_format=load_format
85+
multi_node_only=multi_node_only,
86+
load_format=load_format,
87+
attn_backend=attn_backend,
7988
),
8089
)
8190

@@ -108,11 +117,12 @@ def _compare_cp_with_tp(
108117
tp_size,
109118
pp_size,
110119
dcp_size,
120+
pcp_size,
111121
eager_mode,
112122
chunked_prefill,
113123
) = parallel_setup
114124

115-
multi_node_only, load_format = test_options
125+
multi_node_only, load_format, attn_backend = test_options
116126

117127
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
118128
model_info.check_transformers_version(on_fail="skip")
@@ -155,7 +165,7 @@ def _compare_cp_with_tp(
155165
"--max-model-len",
156166
"2048",
157167
"--max-num-seqs",
158-
"8",
168+
"16",
159169
]
160170
if chunked_prefill:
161171
common_args.append("--enable-chunked-prefill")
@@ -172,6 +182,10 @@ def _compare_cp_with_tp(
172182
if hf_overrides:
173183
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
174184

185+
cp_env = tp_env = {
186+
"VLLM_ATTENTION_BACKEND": attn_backend,
187+
}
188+
175189
cp_args = [
176190
*common_args,
177191
"--tensor-parallel-size",
@@ -180,6 +194,8 @@ def _compare_cp_with_tp(
180194
str(pp_size),
181195
"--decode-context-parallel-size",
182196
str(dcp_size),
197+
"--prefill-context-parallel-size",
198+
str(pcp_size),
183199
"--distributed-executor-backend",
184200
distributed_backend,
185201
]
@@ -198,19 +214,24 @@ def _compare_cp_with_tp(
198214
model_id,
199215
cp_args,
200216
tp_args,
217+
cp_env,
218+
tp_env,
201219
method=method,
202220
max_wait_seconds=720,
203221
)
204222

205223

206224
CP_TEXT_GENERATION_MODELS = {
225+
# [MLA attention only]
207226
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
208227
CPTestSettings.detailed(),
209228
CPTestSettings.detailed(tp_base=2),
210229
],
211230
"bigcode/gpt_bigcode-santacoder": [
212231
CPTestSettings.detailed(),
213232
CPTestSettings.detailed(tp_base=2),
233+
CPTestSettings.detailed(attn_backend="FLASHINFER"),
234+
CPTestSettings.detailed(tp_base=2, attn_backend="FLASHINFER"),
214235
],
215236
}
216237

vllm/attention/backends/abstract.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ class AttentionImpl(ABC, Generic[T]):
127127
dcp_world_size: int
128128
dcp_rank: int
129129

130+
pcp_world_size: int
131+
pcp_rank: int
132+
130133
def __new__(cls, *args, **kwargs):
131134
# use __new__ so that all subclasses will call this
132135
self = super().__new__(cls)
@@ -139,6 +142,16 @@ def __new__(cls, *args, **kwargs):
139142
# DCP might not be initialized in testing
140143
self.dcp_world_size = 1
141144
self.dcp_rank = 0
145+
try:
146+
from vllm.distributed.parallel_state import get_pcp_group
147+
148+
self.pcp_world_size = get_pcp_group().world_size
149+
self.pcp_rank = get_pcp_group().rank_in_group
150+
except AssertionError:
151+
# PCP might not be initialized in testing
152+
self.pcp_world_size = 1
153+
self.pcp_rank = 0
154+
142155
self.need_to_return_lse_for_decode = (
143156
self.dcp_world_size > 1 and self.can_return_lse_for_decode
144157
)

vllm/attention/ops/common.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,36 @@ def cp_lse_ag_out_rs(
205205
return out
206206

207207

208+
def cp_lse_ag_out_ar(
209+
cp_attn_out: torch.Tensor,
210+
cp_attn_lse: torch.Tensor,
211+
cp_group: GroupCoordinator,
212+
ctx: CPTritonContext = None,
213+
):
214+
"""
215+
cp_attn_out: [ B, H, D ]
216+
cp_attn_lse: [ B, H ]
217+
"""
218+
if cp_group.world_size == 1:
219+
return cp_attn_out
220+
221+
if ctx is None:
222+
ctx = CPTritonContext()
223+
224+
lses = torch.empty(
225+
(cp_group.world_size,) + cp_attn_lse.shape,
226+
dtype=cp_attn_lse.dtype,
227+
device=cp_attn_lse.device,
228+
)
229+
230+
cp_attn_lse = cp_attn_lse.contiguous()
231+
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
232+
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
233+
assert out.is_contiguous()
234+
out = cp_group.all_reduce(out)
235+
return out
236+
237+
208238
@triton.jit
209239
def _pack_seq_kernel(
210240
x_ptr, # [N, D]

vllm/config/parallel.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class ParallelConfig:
7171
"""Number of pipeline parallel groups."""
7272
tensor_parallel_size: int = 1
7373
"""Number of tensor parallel groups."""
74+
prefill_context_parallel_size: int = 1
75+
"""Number of prefill context parallel groups."""
7476
data_parallel_size: int = 1
7577
"""Number of data parallel groups. MoE layers will be sharded according to
7678
the product of the tensor parallel size and data parallel size."""
@@ -467,7 +469,11 @@ def __post_init__(self) -> None:
467469
)
468470

469471
# Continue with the rest of the initialization
470-
self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
472+
self.world_size = (
473+
self.pipeline_parallel_size
474+
* self.tensor_parallel_size
475+
* self.prefill_context_parallel_size
476+
)
471477

472478
if self.distributed_executor_backend == "external_launcher":
473479
logger.info("Using external launcher for distributed inference.")

vllm/distributed/parallel_state.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,6 +1085,24 @@ def get_pp_group() -> GroupCoordinator:
10851085
return _PP
10861086

10871087

1088+
_PCP: GroupCoordinator | None = None
1089+
1090+
1091+
def get_pcp_group() -> GroupCoordinator:
1092+
assert _PCP is not None, "prefill context parallel group is not initialized"
1093+
return _PCP
1094+
1095+
1096+
def get_prefill_context_model_parallel_world_size():
1097+
"""Return world size for the prefill context parallel group."""
1098+
return get_pcp_group().world_size
1099+
1100+
1101+
def get_prefill_context_model_parallel_rank():
1102+
"""Return my rank for the prefill context parallel group."""
1103+
return get_pcp_group().rank_in_group
1104+
1105+
10881106
@deprecated(
10891107
"`get_pipeline_model_parallel_group` has been replaced with "
10901108
"`get_pp_group` and may be removed in v0.12. Please use "
@@ -1207,6 +1225,7 @@ def init_distributed_environment(
12071225
def initialize_model_parallel(
12081226
tensor_model_parallel_size: int = 1,
12091227
pipeline_model_parallel_size: int = 1,
1228+
context_model_parallel_size: int = 1,
12101229
decode_context_model_parallel_size: int | None = 1,
12111230
backend: str | None = None,
12121231
) -> None:
@@ -1256,7 +1275,11 @@ def initialize_model_parallel(
12561275
# to get group_ranks for each dimension, transpose that dimension to the
12571276
# last dimension, then reshape to 2D, then unbind the last dimension
12581277
all_ranks = torch.arange(world_size).reshape(
1259-
-1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size
1278+
-1,
1279+
data_parallel_size,
1280+
pipeline_model_parallel_size,
1281+
context_model_parallel_size,
1282+
tensor_model_parallel_size,
12601283
) # noqa
12611284

12621285
# Build the tensor model-parallel groups.
@@ -1295,7 +1318,7 @@ def initialize_model_parallel(
12951318
global _PP
12961319
assert _PP is None, "pipeline model parallel group is already initialized"
12971320
group_ranks = (
1298-
all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0)
1321+
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
12991322
)
13001323
group_ranks = [x.tolist() for x in group_ranks]
13011324
_PP = init_model_parallel_group(
@@ -1304,7 +1327,7 @@ def initialize_model_parallel(
13041327

13051328
global _DP
13061329
assert _DP is None, "data parallel group is already initialized"
1307-
group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0)
1330+
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
13081331
group_ranks = [x.tolist() for x in group_ranks]
13091332
_DP = init_model_parallel_group(
13101333
group_ranks, get_world_group().local_rank, backend, group_name="dp"
@@ -1314,29 +1337,46 @@ def initialize_model_parallel(
13141337
assert _EP is None, "expert parallel group is already initialized"
13151338
group_ranks = (
13161339
all_ranks.transpose(1, 2)
1317-
.reshape(-1, data_parallel_size * tensor_model_parallel_size)
1340+
.reshape(
1341+
-1,
1342+
data_parallel_size
1343+
* tensor_model_parallel_size
1344+
* context_model_parallel_size,
1345+
)
13181346
.unbind(0)
13191347
)
13201348
group_ranks = [x.tolist() for x in group_ranks]
13211349
_EP = init_model_parallel_group(
13221350
group_ranks, get_world_group().local_rank, backend, group_name="ep"
13231351
)
13241352

1353+
global _PCP
1354+
assert _PCP is None, "prefill context parallel group is already initialized"
1355+
group_ranks = (
1356+
all_ranks.transpose(3, 4).reshape(-1, context_model_parallel_size).unbind(0)
1357+
)
1358+
group_ranks = [x.tolist() for x in group_ranks]
1359+
_PCP = init_model_parallel_group(
1360+
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
1361+
)
1362+
13251363
logger.info(
13261364
"rank %s in world size %s is assigned as "
1327-
"DP rank %s, PP rank %s, TP rank %s, EP rank %s",
1365+
"DP rank %s, PP rank %s, TP rank %s, EP rank %s, PCP rank %s",
13281366
rank,
13291367
world_size,
13301368
_DP.rank_in_group,
13311369
_PP.rank_in_group,
13321370
_TP.rank_in_group,
13331371
_EP.rank_in_group,
1372+
_PCP.rank_in_group,
13341373
)
13351374

13361375

13371376
def ensure_model_parallel_initialized(
13381377
tensor_model_parallel_size: int,
13391378
pipeline_model_parallel_size: int,
1379+
prefill_context_model_parallel_size: int = 1,
13401380
decode_context_model_parallel_size: int | None = 1,
13411381
backend: str | None = None,
13421382
) -> None:
@@ -1349,6 +1389,7 @@ def ensure_model_parallel_initialized(
13491389
initialize_model_parallel(
13501390
tensor_model_parallel_size,
13511391
pipeline_model_parallel_size,
1392+
prefill_context_model_parallel_size,
13521393
decode_context_model_parallel_size,
13531394
backend,
13541395
)
@@ -1365,6 +1406,12 @@ def ensure_model_parallel_initialized(
13651406
f"got: {pp_world_size=} vs. "
13661407
f"wanted: {pipeline_model_parallel_size=}"
13671408
)
1409+
pcp_world_size = get_pcp_group().world_size
1410+
assert pcp_world_size == prefill_context_model_parallel_size, (
1411+
"prefill context parallel group already initialized, but of unexpected size: "
1412+
f"{pcp_world_size=} vs. "
1413+
f"{prefill_context_model_parallel_size=}"
1414+
)
13681415

13691416

13701417
def prepare_communication_buffer_for_model(model: torch.nn.Module):
@@ -1382,6 +1429,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
13821429
_DP.prepare_communication_buffer_for_model(model)
13831430
if _EP is not None:
13841431
_EP.prepare_communication_buffer_for_model(model)
1432+
if _PCP is not None:
1433+
_PCP.prepare_communication_buffer_for_model(model)
13851434

13861435

13871436
def model_parallel_is_initialized():
@@ -1471,6 +1520,11 @@ def destroy_model_parallel():
14711520
_EP.destroy()
14721521
_EP = None
14731522

1523+
global _PCP
1524+
if _PCP:
1525+
_PCP.destroy()
1526+
_PCP = None
1527+
14741528

14751529
def destroy_distributed_environment():
14761530
global _WORLD, _NODE_COUNT

0 commit comments

Comments
 (0)