[BE] replace the extra DeviceMesh _flatten with mesh access #667

Merged 2 commits on Oct 31, 2024
42 changes: 26 additions & 16 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -34,6 +34,7 @@
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.utils import check_if_feature_in_pytorch


def parallelize_llama(
@@ -79,22 +80,31 @@ def parallelize_llama(
if (
parallel_dims.dp_shard_enabled
): # apply FSDP or HSDP, potentially with Context Parallel

# TODO: instead of flattening the mesh twice, we could have done this in a better way:
# dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
# However, this leads to an error in `DeviceMesh.__getitem__` which I believe is
# a bug in DeviceMesh. We should fix it and then use the above line.
dp_mesh_dim_names = (
("dp_replicate", "dp_shard")
if parallel_dims.dp_replicate_enabled
else ("dp",)
)
# note that mesh can only be flattened from the finest-grained mesh dimensions
dp_mesh = (
world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
if parallel_dims.cp_enabled
else world_mesh[dp_mesh_dim_names]
)
try:
dp_mesh = (
world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
)
except IndexError:
# note: this is a workaround of the above logic for older PyTorch versions
# where https://github.com/pytorch/pytorch/pull/138945 is not included;
# throw a warning to encourage users to upgrade to a newer PyTorch version
check_if_feature_in_pytorch(
Contributor:
@fegin
Since we are already suggesting users use the latest PyTorch nightly, I wonder how necessary it is to include such messages, as multiple features would depend on non-stable releases. E.g. for PP, we constantly see new features without such try-catches.

Not saying this is not good, but there seems to be a trade-off here, sacrificing some dev velocity and simplicity.

Contributor:
I'm okay with removing it. Or we can periodically clean up (like weekly). In general, this is not just for end users but also for developers like us who may not update their local PyTorch regularly.

"DeviceMesh flattening over 3D+ meshes",
"https://github.com/pytorch/pytorch/pull/138945",
"2.6.0.dev20241030",
)
# TODO: remove this workaround once PyTorch 2.6 is released
dp_mesh_dim_names = (
("dp_replicate", "dp_shard")
if parallel_dims.dp_replicate_enabled
else ("dp",)
)
# note that mesh can only be flattened from the finest-grained mesh dimensions
dp_mesh = (
world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
if parallel_dims.cp_enabled
else world_mesh[dp_mesh_dim_names]
)

apply_fsdp(
model,
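For context, a minimal sketch of the mesh-access pattern the try branch relies on. This is not part of the PR; it assumes an 8-rank job laid out as 2 x 2 x 2 over dp_replicate, dp_shard, and cp, with torch.distributed already initialized (e.g. via torchrun):

from torch.distributed.device_mesh import init_device_mesh

# assumed layout for illustration: 2 (dp_replicate) x 2 (dp_shard) x 2 (cp) ranks
world_mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "cp")
)

# flattening registers a new named dim ("dp_cp") on the root mesh
world_mesh[("dp_replicate", "dp_shard", "cp")]._flatten("dp_cp")

# with https://github.com/pytorch/pytorch/pull/138945 included, the flattened
# dim can later be looked up by name instead of being flattened again; on older
# builds this lookup fails, which is what the except IndexError branch above
# falls back from
dp_mesh = world_mesh["dp_cp"]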
28 changes: 28 additions & 0 deletions torchtitan/parallelisms/utils.py
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from torchtitan.logging import logger


def check_if_feature_in_pytorch(
    feature_name: str,
    pull_request: str,
    min_nightly_version: Optional[str] = None,
) -> None:
    if "git" in torch.__version__:  # pytorch is built from source
        # notify users to check if the pull request is included in their pytorch
        logger.warning(
            "Detected that PyTorch is built from source. Please make sure the PR "
            f"({pull_request}) is included in PyTorch for correct {feature_name}."
        )
    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
        logger.warning(
            f"Detected that the PyTorch version {torch.__version__} is older than "
            f"{min_nightly_version}. Please upgrade to a newer version to include the "
            f"change in ({pull_request}) for correct {feature_name}."
        )