Skip to content

Commit 4f729c2

Browse files
authored
Revert "[BE] replace the extra DeviceMesh _flatten with mesh access (#666)"
This reverts commit 3653bf2.
1 parent 3653bf2 commit 4f729c2

File tree

2 files changed

+16
-54
lines changed

2 files changed

+16
-54
lines changed

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
3535
from torchtitan.logging import logger
3636
from torchtitan.parallelisms.parallel_dims import ParallelDims
37-
from torchtitan.parallelisms.utils import check_if_feature_in_pytorch
3837

3938

4039
def parallelize_llama(
@@ -80,31 +79,22 @@ def parallelize_llama(
8079
if (
8180
parallel_dims.dp_shard_enabled
8281
): # apply FSDP or HSDP, potentially with Context Parallel
83-
try:
84-
dp_mesh = (
85-
world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
86-
)
87-
except IndexError:
88-
# note: this is a workaround of the above logic for old pytorch version
89-
# where https://github.com/pytorch/pytorch/pull/138945 is not included
90-
# throw a warning to encourage users to upgrade to a newer pytorch version
91-
check_if_feature_in_pytorch(
92-
"DeviceMesh flattening over 3D+ meshes",
93-
"https://github.com/pytorch/pytorch/pull/138945",
94-
"2.6.0.dev20241030",
95-
)
96-
# TODO: remove this workaround once PyTorch 2.6 is released
97-
dp_mesh_dim_names = (
98-
("dp_replicate", "dp_shard")
99-
if parallel_dims.dp_replicate_enabled
100-
else ("dp",)
101-
)
102-
# note that mesh can only be flattened from the finest-grained mesh dimensions
103-
dp_mesh = (
104-
world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
105-
if parallel_dims.cp_enabled
106-
else world_mesh[dp_mesh_dim_names]
107-
)
82+
83+
# TODO: instead of flattening the mesh twice, we could've done it in a better way:
84+
# dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
85+
# However, this leads to an error in `DeviceMesh.__getitem__` which I believe is
86+
# a bug in DeviceMesh. We should fix it and then use the above line.
87+
dp_mesh_dim_names = (
88+
("dp_replicate", "dp_shard")
89+
if parallel_dims.dp_replicate_enabled
90+
else ("dp",)
91+
)
92+
# note that mesh can only be flattened from the finest-grained mesh dimensions
93+
dp_mesh = (
94+
world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
95+
if parallel_dims.cp_enabled
96+
else world_mesh[dp_mesh_dim_names]
97+
)
10898

10999
apply_fsdp(
110100
model,

torchtitan/parallelisms/utils.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

0 commit comments

Comments
 (0)