File tree Expand file tree Collapse file tree 1 file changed +1
-15
lines changed Expand file tree Collapse file tree 1 file changed +1
-15
lines changed Original file line number Diff line number Diff line change @@ -81,21 +81,7 @@ def parallelize_llama(
8181 parallel_dims .dp_shard_enabled
8282 ): # apply FSDP or HSDP, potentially with Context Parallel
8383
84- # TODO: instead of flattening the mesh twice, we could've done in a batter way:
85- # dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
86- # However, this leads to an error in `DeviceMesh.__get_item__` which I believe is
87- # a bug in DeviceMesh. We should fix it and then use the above line.
88- dp_mesh_dim_names = (
89- ("dp_replicate" , "dp_shard" )
90- if parallel_dims .dp_replicate_enabled
91- else ("dp" ,)
92- )
93- # note that mesh can only be flattened from the finest-grained mesh dimensions
94- dp_mesh = (
95- world_mesh [(* dp_mesh_dim_names , "cp" )]._flatten ("dp_cp" )
96- if parallel_dims .cp_enabled
97- else world_mesh [dp_mesh_dim_names ]
98- )
84+ dp_mesh = world_mesh ["dp_cp" ] if parallel_dims .cp_enabled else world_mesh ["dp" ]
9985
10086 apply_fsdp (
10187 model ,
You can’t perform that action at this time.
0 commit comments