@@ -1085,6 +1085,24 @@ def get_pp_group() -> GroupCoordinator:
     return _PP
 
 
+_PCP: GroupCoordinator | None = None
+
+
+def get_pcp_group() -> GroupCoordinator:
+    assert _PCP is not None, "prefill context parallel group is not initialized"
+    return _PCP
+
+
+def get_prefill_context_model_parallel_world_size():
+    """Return world size for the prefill context model parallel group."""
+    return get_pcp_group().world_size
+
+
+def get_prefill_context_model_parallel_rank():
+    """Return my rank for the prefill context model parallel group."""
+    return get_pcp_group().rank_in_group
+
+
 @deprecated(
     "`get_pipeline_model_parallel_group` has been replaced with "
     "`get_pp_group` and may be removed in v0.12. Please use "
@@ -1207,6 +1225,7 @@ def init_distributed_environment(
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    context_model_parallel_size: int = 1,
     decode_context_model_parallel_size: int | None = 1,
     backend: str | None = None,
 ) -> None:
@@ -1256,7 +1275,11 @@ def initialize_model_parallel(
     # to get group_ranks for each dimension, transpose that dimension to the
     # last dimension, then reshape to 2D, then unbind the last dimension
     all_ranks = torch.arange(world_size).reshape(
-        -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size
+        -1,
+        data_parallel_size,
+        pipeline_model_parallel_size,
+        context_model_parallel_size,
+        tensor_model_parallel_size,
     )  # noqa
 
     # Build the tensor model-parallel groups.
@@ -1295,7 +1318,7 @@ def initialize_model_parallel(
     global _PP
     assert _PP is None, "pipeline model parallel group is already initialized"
     group_ranks = (
-        all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0)
+        all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
     )
     group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(
@@ -1304,7 +1327,7 @@ def initialize_model_parallel(
 
     global _DP
     assert _DP is None, "data parallel group is already initialized"
-    group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0)
+    group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _DP = init_model_parallel_group(
         group_ranks, get_world_group().local_rank, backend, group_name="dp"
@@ -1314,29 +1337,46 @@ def initialize_model_parallel(
     assert _EP is None, "expert parallel group is already initialized"
     group_ranks = (
         all_ranks.transpose(1, 2)
-        .reshape(-1, data_parallel_size * tensor_model_parallel_size)
+        .reshape(
+            -1,
+            data_parallel_size
+            * tensor_model_parallel_size
+            * context_model_parallel_size,
+        )
         .unbind(0)
     )
     group_ranks = [x.tolist() for x in group_ranks]
     _EP = init_model_parallel_group(
         group_ranks, get_world_group().local_rank, backend, group_name="ep"
     )
 
+    global _PCP
+    assert _PCP is None, "prefill context parallel group is already initialized"
+    group_ranks = (
+        all_ranks.transpose(3, 4).reshape(-1, context_model_parallel_size).unbind(0)
+    )
+    group_ranks = [x.tolist() for x in group_ranks]
+    _PCP = init_model_parallel_group(
+        group_ranks, get_world_group().local_rank, backend, group_name="pcp"
+    )
+
     logger.info(
         "rank %s in world size %s is assigned as "
-        "DP rank %s, PP rank %s, TP rank %s, EP rank %s",
+        "DP rank %s, PP rank %s, TP rank %s, EP rank %s, PCP rank %s",
         rank,
         world_size,
         _DP.rank_in_group,
         _PP.rank_in_group,
         _TP.rank_in_group,
         _EP.rank_in_group,
+        _PCP.rank_in_group,
     )
 
 
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    prefill_context_model_parallel_size: int = 1,
     decode_context_model_parallel_size: int | None = 1,
     backend: str | None = None,
 ) -> None:
@@ -1349,6 +1389,7 @@ def ensure_model_parallel_initialized(
     initialize_model_parallel(
         tensor_model_parallel_size,
         pipeline_model_parallel_size,
+        prefill_context_model_parallel_size,
        decode_context_model_parallel_size,
         backend,
     )
@@ -1365,6 +1406,12 @@ def ensure_model_parallel_initialized(
         f"got: {pp_world_size=} vs. "
         f"wanted: {pipeline_model_parallel_size=}"
     )
+    pcp_world_size = get_pcp_group().world_size
+    assert pcp_world_size == prefill_context_model_parallel_size, (
+        "prefill context parallel group already initialized, but of unexpected size: "
+        f"{pcp_world_size=} vs. "
+        f"{prefill_context_model_parallel_size=}"
+    )
 
 
 def prepare_communication_buffer_for_model(model: torch.nn.Module):
@@ -1382,6 +1429,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
         _DP.prepare_communication_buffer_for_model(model)
     if _EP is not None:
         _EP.prepare_communication_buffer_for_model(model)
+    if _PCP is not None:
+        _PCP.prepare_communication_buffer_for_model(model)
 
 
 def model_parallel_is_initialized():
@@ -1471,6 +1520,11 @@ def destroy_model_parallel():
         _EP.destroy()
     _EP = None
 
+    global _PCP
+    if _PCP:
+        _PCP.destroy()
+    _PCP = None
+
 
 def destroy_distributed_environment():
     global _WORLD, _NODE_COUNT
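
To make the index arithmetic above concrete, here is a small standalone sketch (not part of the diff); the parallel sizes below are illustrative only. It builds the same 5-D rank grid and shows why inserting the new PCP axis at position 3 pushes TP to position 4, so the PP and DP group constructions now transpose against index 4 while PCP itself transposes (3, 4).

import torch

# Illustrative sizes only: 8 ranks with DP=1, PP=2, PCP=2, TP=2
# (the leading -1 dimension absorbs any external data parallelism).
data_parallel_size = 1
pipeline_model_parallel_size = 2
context_model_parallel_size = 2
tensor_model_parallel_size = 2
world_size = 8

all_ranks = torch.arange(world_size).reshape(
    -1,
    data_parallel_size,
    pipeline_model_parallel_size,
    context_model_parallel_size,
    tensor_model_parallel_size,
)

# Same recipe as the diff: move the dimension of interest to the end,
# flatten everything else, and unbind so each row is one process group.
tp_groups = all_ranks.reshape(-1, tensor_model_parallel_size).unbind(0)
pcp_groups = (
    all_ranks.transpose(3, 4).reshape(-1, context_model_parallel_size).unbind(0)
)
pp_groups = (
    all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
)

print([g.tolist() for g in tp_groups])   # [[0, 1], [2, 3], [4, 5], [6, 7]]
print([g.tolist() for g in pcp_groups])  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print([g.tolist() for g in pp_groups])   # [[0, 4], [2, 6], [1, 5], [3, 7]]

With these sizes, TP peers are adjacent ranks, PCP peers are two ranks apart, and PP peers four apart, and every rank lands in exactly one group per dimension; that is what lets initialize_model_parallel derive the TP, PP, DP, EP, and now PCP groups from a single all_ranks tensor.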