
Commit a69cbf4

[BE] replace the extra DeviceMesh _flatten with mesh access
ghstack-source-id: 6afa471
Pull Request resolved: #666
1 parent: 53d0f69

1 file changed: +1 −15 lines changed

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 1 addition & 15 deletions
@@ -81,21 +81,7 @@ def parallelize_llama(
         parallel_dims.dp_shard_enabled
     ):  # apply FSDP or HSDP, potentially with Context Parallel

-        # TODO: instead of flattening the mesh twice, we could've done in a batter way:
-        # dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
-        # However, this leads to an error in `DeviceMesh.__get_item__` which I believe is
-        # a bug in DeviceMesh. We should fix it and then use the above line.
-        dp_mesh_dim_names = (
-            ("dp_replicate", "dp_shard")
-            if parallel_dims.dp_replicate_enabled
-            else ("dp",)
-        )
-        # note that mesh can only be flattened from the finest-grained mesh dimensions
-        dp_mesh = (
-            world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
-            if parallel_dims.cp_enabled
-            else world_mesh[dp_mesh_dim_names]
-        )
+        dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]

         apply_fsdp(
             model,
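
For context, the one-line access works because a "dp_cp" mesh dimension only has to be flattened once; after that, any consumer can slice the world mesh by that name. The commit title suggests the flatten now happens once where the world mesh is built, so this code can simply index it. Below is a minimal, self-contained sketch of that pattern, not torchtitan code: the device type, mesh shape, and script name are illustrative assumptions, and DeviceMesh._flatten is a private PyTorch API that may change between releases.

# mesh_sketch.py -- minimal sketch, not torchtitan code.
# Shows the pattern this commit relies on: flatten ("dp", "cp") into a
# "dp_cp" dim once, then access it by name instead of re-flattening at the
# use site. Run under torchrun, e.g.:
#     torchrun --nproc-per-node=4 mesh_sketch.py
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 2x2 world mesh over 4 ranks; "cpu"/gloo keeps the sketch GPU-free.
world_mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "cp"))

# Flatten once, up front. _flatten is a private DeviceMesh method.
world_mesh["dp", "cp"]._flatten("dp_cp")

# Downstream code can now index the flattened dim by name, mirroring
#     dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
dp_mesh = world_mesh["dp_cp"]
print(f"rank {dist.get_rank()}: dp_cp mesh spans {dp_mesh.size()} ranks")

dist.destroy_process_group()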

0 commit comments
