 class ParallelDims:
     dp_replicate: int
     dp_shard: int
+    cp: int
     tp: int
     pp: int
     world_size: int
@@ -24,36 +25,38 @@ def __post_init__(self):
         self._validate()
 
     def _validate(self):
-        dp_replicate, dp_shard, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp = (
             self.dp_replicate,
             self.dp_shard,
+            self.cp,
             self.tp,
             self.pp,
         )
-        for d in (dp_replicate, tp, pp):
+        for d in (dp_replicate, cp, tp, pp):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
         assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."
 
         dp = dp_replicate * dp_shard
         if dp < 0:
-            dp = self.world_size // (tp * pp)
+            dp = self.world_size // (cp * tp * pp)
             self.dp_shard = dp_shard = dp // dp_replicate
 
         assert dp_replicate >= 1
         assert dp_shard >= 1
+        assert cp >= 1, cp
         assert tp >= 1, tp
         assert pp >= 1, pp
-        assert dp_replicate * dp_shard * tp * pp == self.world_size, (
+        assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
             f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
-            f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+            f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
         )
 
     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp_replicate, self.dp_shard, self.tp],
-            ["pp", "dp_replicate", "dp_shard", "tp"],
+            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
+            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
             if d > 1:
                 dims.append(d)
@@ -71,6 +74,13 @@ def build_mesh(self, device_type):
         # initialized
         if self.dp_replicate > 1 and self.dp_shard > 1:
             mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
+
+        if self.cp > 1:
+            if self.dp_replicate > 1 and self.dp_shard > 1:
+                mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
+            else:
+                mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
+
         return mesh
 
     @property
@@ -85,6 +95,10 @@ def dp_replicate_enabled(self):
     def dp_shard_enabled(self):
         return self.dp_shard > 1
 
+    @property
+    def cp_enabled(self):
+        return self.cp > 1
+
     @property
     def tp_enabled(self):
         return self.tp > 1
@@ -98,5 +112,5 @@ def loss_parallel_enabled(self):
         return self.tp > 1 and self.enable_loss_parallel
 
     @cached_property
-    def model_parallel_size(self):
-        return self.tp * self.pp
+    def non_data_parallel_size(self):
+        return self.cp * self.tp * self.pp
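
For reference, a minimal usage sketch of the updated class follows. It is not part of the commit: the import path and the enable_loss_parallel field are assumptions inferred from the attributes referenced above, and building the mesh only works under a real distributed launch (e.g. torchrun with 16 ranks).

# Sketch only, assuming ParallelDims lives in torchtitan's parallelisms package
# and also defines an enable_loss_parallel field (used by loss_parallel_enabled).
from torchtitan.parallelisms.parallel_dims import ParallelDims  # assumed path

# 16 ranks: FSDP sharding inferred, 2-way context parallel, 4-way tensor parallel.
parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=-1,   # -1 lets _validate() infer dp_shard = 16 // (cp * tp * pp) = 2
    cp=2,
    tp=4,
    pp=1,
    world_size=16,
    enable_loss_parallel=True,  # assumed field
)

assert parallel_dims.cp_enabled
assert parallel_dims.non_data_parallel_size == 2 * 4 * 1  # cp * tp * pp

# build_mesh() only materializes dims with degree > 1, so the mesh here has
# shape (2, 2, 4). With cp > 1 and only one of dp_replicate/dp_shard > 1, the
# else branch above flattens mesh["dp", "cp"] into a "dp_cp" dim, letting data
# loading and loss reduction treat dp and cp ranks as a single group.
mesh = parallel_dims.build_mesh(device_type="cuda")  # requires torchrun, 16 ranks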