[BE] replace the extra DeviceMesh _flatten with mesh access #666
Conversation
    -     if parallel_dims.cp_enabled
    -     else world_mesh[dp_mesh_dim_names]
    - )
    + dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
Is this new DeviceMesh functionality that reacts specifically to `<name1>_<name2>`?
This is not new. DeviceMesh has supported `world_mesh["<name1>_<name2>"]` ever since the `_flatten` behavior was implemented. However, it had a bug: if the flattened mesh is constructed from 3+ mesh dimensions (e.g. "dp_cp" flattened from "dp_shard", "dp_replicate", and "cp"), accessing `world_mesh["dp_cp"]` throws an error, which breaks 3D/4D/5D composability.
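For illustration, a minimal sketch of the pattern under discussion (not taken from the PR; the device type, mesh sizes, and launch setup are assumptions):

```python
# Minimal sketch, assuming 8 ranks launched via torchrun; sizes are illustrative.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh(
    "cuda",
    (2, 2, 2),
    mesh_dim_names=("dp_replicate", "dp_shard", "cp"),
)

# Flatten three mesh dimensions into a single "dp_cp" dimension.
world_mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")

# Access the flattened dimension by name; before pytorch/pytorch#138945 this
# raised an error whenever the flattened dim was built from 3+ dimensions.
dp_cp_mesh = world_mesh["dp_cp"]
```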
Can we catch the error and ask users to update to some version?
Just to check my understanding: for dp, if HSDP is enabled, "dp" is the flattened mesh of "dp_replicate" and "dp_shard", right? Otherwise, "dp" is just "dp_shard".
@wz337, that's right. To summarize (a small sketch follows this list):
- FSDP: the only dp dimension in the mesh is "dp"
- DDP: the only dp dimension in the mesh is "dp"
- HSDP: the basic dp dimensions in the mesh are "dp_shard" and "dp_replicate", which are later flattened into "dp"
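A minimal sketch of the HSDP case (mesh sizes and device type are assumptions, not torchtitan code):

```python
# Minimal sketch, assuming 8 ranks: dp_replicate=2, dp_shard=4.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh(
    "cuda",
    (2, 4),
    mesh_dim_names=("dp_replicate", "dp_shard"),
)

# Flatten the two basic dp dimensions into a single "dp" dimension...
world_mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")

# ...so downstream code can index world_mesh["dp"] uniformly across FSDP/DDP/HSDP.
dp_mesh = world_mesh["dp"]
```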
tianyu-l left a comment:
lgtm
fegin left a comment:
It would be better to have a try-except that tells users they are not running the latest PyTorch.
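Roughly along these lines (a sketch of the suggestion, not the code this PR merged; the exception types and message wording are assumptions):

```python
# Sketch of the suggested fallback; which exceptions to catch is an assumption.
try:
    dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
except (KeyError, RuntimeError) as e:
    raise RuntimeError(
        "Accessing the flattened mesh by name failed; please update PyTorch to a "
        "version that includes pytorch/pytorch#138945."
    ) from e
```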
> Can we catch the error and ask users to update to some version?

Oh yeah, that's right...
**Summary**

pytorch/pytorch#138945 fixes DeviceMesh access on flattened meshes that are constructed from more than two meshes; refer to that PR for details if interested. In #592 we avoided the issue by calling `_flatten` instead of directly accessing the flattened mesh. Now that the fix has been merged in PyTorch, we switch back to mesh access, which is more straightforward.

Note: this PR was mistakenly merged into the wrong branch and was relanded as #667 (stack from [ghstack](https://github.com/ezyang/ghstack)).