Skip to content

Conversation

@XilunWu
Copy link
Contributor

@XilunWu XilunWu commented Oct 30, 2024

Stack from ghstack (oldest at bottom):

Summary
pytorch/pytorch#138945 fixes DeviceMesh access on flattened mesh which are constructed from more than 2 meshes. Refer to the fix PR for details if interested.

In #592 we avoided this issue by calling _flatten instead of direct accessing the flattened mesh. We want to turn back to mesh access which is more straightforward since the fix has been merged in PyTorch.

XilunWu added a commit that referenced this pull request Oct 30, 2024
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 30, 2024
if parallel_dims.cp_enabled
else world_mesh[dp_mesh_dim_names]
)
dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a new DeviceMesh functionality that reacts specifically to <name1>_<name2>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not new. DeviceMesh supports world_mesh[<name1>_<name2>] when the _flatten behavior was implemented. However, it has a bug -- if the flattened mesh is constructed from 3+ mesh dimensions (e.g. dp_cp is flattened using dp_shard, dp_replicate, and cp. Accessing world_mesh[dp_cp] throws error which breaks 3D/4D/5D composability).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we catch the error and ask users to update to some version?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding, for dp, if hsdp is enabled, "dp" is the flatten mesh for "dp_replicate", "dp_shard", right? Otherwise, "dp" is just "dp_shard".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wz337 , that's right. To summarize:

  1. FSDP: the only dp dimension in mesh is "dp"
  2. DDP: the only dp dimension in mesh is "dp"
  3. HSDP: the basic dp dimensions in mesh are "dp_shard" and "dp_replicate", which are later on flattened into "dp"

@XilunWu XilunWu requested review from fegin, tianyu-l and wz337 October 30, 2024 22:10
Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

if parallel_dims.cp_enabled
else world_mesh[dp_mesh_dim_names]
)
dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

@fegin fegin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to have a try-except to indicate users are not using the latest PyTorch.

if parallel_dims.cp_enabled
else world_mesh[dp_mesh_dim_names]
)
dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we catch the error and ask users to update to some version?

@XilunWu
Copy link
Contributor Author

XilunWu commented Oct 30, 2024

It's better to have a try-except to indicate users are not using the latest PyTorch.

Oh yeah that's right...

**Summary**
pytorch/pytorch#138945 fixes DeviceMesh access on flattened mesh which are constructed from more than 2 meshes. Refer to the fix PR for details if interested.

In #592 we avoided this issue by calling `_flatten` instead of direct accessing the flattened mesh. We want to turn back to mesh access which is more straightforward since the fix has been merged in PyTorch.


[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Oct 31, 2024
@XilunWu XilunWu merged commit 3653bf2 into gh/XilunWu/9/base Oct 31, 2024
5 checks passed
XilunWu added a commit that referenced this pull request Oct 31, 2024
XilunWu added a commit that referenced this pull request Oct 31, 2024
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* __->__ #667

Note: This PR is a reland of #666 where the PR was mistakenly merged
into a wrong branch.

**Summary**
pytorch/pytorch#138945 fixes DeviceMesh access
on flattened mesh which are constructed from more than 2 meshes. Refer
to the fix PR for details if interested.

In #592 we avoided this issue by calling `_flatten` instead of direct
accessing the flattened mesh. We want to turn back to mesh access which
is more straightforward since the fix has been merged in PyTorch.
mori360 pushed a commit to mori360/torchtitan that referenced this pull request Nov 26, 2024
)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* __->__ pytorch#667

Note: This PR is a reland of pytorch#666 where the PR was mistakenly merged
into a wrong branch.

**Summary**
pytorch/pytorch#138945 fixes DeviceMesh access
on flattened mesh which are constructed from more than 2 meshes. Refer
to the fix PR for details if interested.

In pytorch#592 we avoided this issue by calling `_flatten` instead of direct
accessing the flattened mesh. We want to turn back to mesh access which
is more straightforward since the fix has been merged in PyTorch.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants