
Commit d7cabfb

[BE] replace the extra DeviceMesh _flatten with mesh access (#667)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* __->__ #667

Note: this PR is a reland of #666, which was mistakenly merged into the wrong branch.

**Summary**

pytorch/pytorch#138945 fixes DeviceMesh access on flattened meshes constructed from more than 2 meshes; refer to that PR for details if interested. In #592 we avoided this issue by calling `_flatten` instead of directly accessing the flattened mesh. Now that the fix has been merged into PyTorch, we switch back to mesh access, which is more straightforward.
1 parent e1fbced commit d7cabfb
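
As background for the summary above, here is a minimal, hypothetical sketch (not part of this commit) contrasting the explicit `_flatten` call used in #592 with the direct flattened-mesh access this commit switches back to. It assumes a CPU/Gloo run launched via `torchrun --nproc_per_node=16`, a 4D mesh whose dimension names mirror torchtitan's, and use of the private `DeviceMesh._flatten` API.

```python
# Hypothetical sketch, not part of this commit: contrasts explicit flattening
# with direct access to the flattened mesh dimension.
# Assumptions: launched via `torchrun --nproc_per_node=16 sketch.py` so the
# default process group has 16 ranks matching the mesh below (CPU + Gloo).
from torch.distributed.device_mesh import init_device_mesh

# A mesh in the spirit of torchtitan's world mesh for HSDP + context parallel.
world_mesh = init_device_mesh(
    "cpu",
    (2, 2, 2, 2),
    mesh_dim_names=("dp_replicate", "dp_shard", "cp", "tp"),
)

# Explicit flattening, as done in #592: registers a flattened "dp_cp" dim,
# built from more than 2 mesh dims (the case pytorch/pytorch#138945 fixes).
dp_mesh_via_flatten = world_mesh[("dp_replicate", "dp_shard", "cp")]._flatten("dp_cp")

# Direct access, which this commit switches back to; per the commit's fallback,
# this lookup raised IndexError on older PyTorch without the fix above.
dp_mesh_via_access = world_mesh["dp_cp"]
```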


2 files changed: +54, -16 lines


torchtitan/parallelisms/parallelize_llama.py

Lines changed: 26 additions & 16 deletions
@@ -34,6 +34,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
+from torchtitan.parallelisms.utils import check_if_feature_in_pytorch


 def parallelize_llama(
@@ -79,22 +80,31 @@ def parallelize_llama(
     if (
         parallel_dims.dp_shard_enabled
     ):  # apply FSDP or HSDP, potentially with Context Parallel
-
-        # TODO: instead of flattening the mesh twice, we could've done in a batter way:
-        # dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
-        # However, this leads to an error in `DeviceMesh.__get_item__` which I believe is
-        # a bug in DeviceMesh. We should fix it and then use the above line.
-        dp_mesh_dim_names = (
-            ("dp_replicate", "dp_shard")
-            if parallel_dims.dp_replicate_enabled
-            else ("dp",)
-        )
-        # note that mesh can only be flattened from the finest-grained mesh dimensions
-        dp_mesh = (
-            world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
-            if parallel_dims.cp_enabled
-            else world_mesh[dp_mesh_dim_names]
-        )
+        try:
+            dp_mesh = (
+                world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
+            )
+        except IndexError:
+            # note: this is a workaround of the above logic for old pytorch version
+            # where https://github.com/pytorch/pytorch/pull/138945 is not included
+            # throw a warning to encourage users to upgrade to a newer pytorch version
+            check_if_feature_in_pytorch(
+                "DeviceMesh flattening over 3D+ meshes",
+                "https://github.com/pytorch/pytorch/pull/138945",
+                "2.6.0.dev20241030",
+            )
+            # TODO: remove this workaround once PyTorch 2.6 is released
+            dp_mesh_dim_names = (
+                ("dp_replicate", "dp_shard")
+                if parallel_dims.dp_replicate_enabled
+                else ("dp",)
+            )
+            # note that mesh can only be flattened from the finest-grained mesh dimensions
+            dp_mesh = (
+                world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
+                if parallel_dims.cp_enabled
+                else world_mesh[dp_mesh_dim_names]
+            )

         apply_fsdp(
             model,

torchtitan/parallelisms/utils.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Optional
+
+import torch
+from torchtitan.logging import logger
+
+
+def check_if_feature_in_pytorch(
+    feature_name: str,
+    pull_request: str,
+    min_nightly_version: Optional[str] = None,
+) -> None:
+    if "git" in torch.__version__:  # pytorch is built from source
+        # notify users to check if the pull request is included in their pytorch
+        logger.warning(
+            "detected that the pytorch is built from source. Please make sure the PR "
+            f"({pull_request}) is included in pytorch for correct {feature_name}."
+        )
+    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
+        logger.warning(
+            f"detected that the pytorch version {torch.__version__} is older than "
+            f"{min_nightly_version}. Please upgrade to a newer version to include the "
+            f"change in ({pull_request}) for correct {feature_name}."
+        )
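
For reference, a minimal usage sketch of the helper added above (hypothetical, not part of the commit); it mirrors the call site in parallelize_llama.py and assumes torchtitan is installed so that `torchtitan.parallelisms.utils` is importable.

```python
# Hypothetical usage sketch of check_if_feature_in_pytorch; mirrors the call
# site in parallelize_llama.py. Assumes torchtitan is installed/importable.
from torchtitan.parallelisms.utils import check_if_feature_in_pytorch

# Logs a warning if the running PyTorch is a source build (asking the user to
# confirm the PR is included) or reports an older version than the given
# nightly; otherwise it does nothing.
check_if_feature_in_pytorch(
    "DeviceMesh flattening over 3D+ meshes",
    "https://github.com/pytorch/pytorch/pull/138945",
    min_nightly_version="2.6.0.dev20241030",
)
```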
