Avoid nn.Linear decomposition by replacing view -> mm -> view with einsum #26
Conversation
This tries to support CP-style sharding by overcoming a limitation of DTensor. It doesn't work yet, as _mm_strategy is failing.
autoparallel/propagation_rules.py
@register_opschema_rule(torch.ops.aten.matmul.default)
def matmul_rule(mesh, op_schema):
    # from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies
i would have thought to use the einsum strategies here. for my education, what is the difference between einsum and mm_like in this context? cc @XilunWu
I'll have to end up using the einsum strategies, because mm_strategies fail :-)
The difference, I believe, is that mm_strategy already filters out invalid strategies, while the einsum strategies don't, since they only know the number of dimensions, not the sizes of the tensors.
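To make the distinction above concrete, here is a toy illustration (hypothetical, not DTensor's actual code): an einsum-style generator only sees dimension letters, so it proposes sharding every dim, while an mm-style generator also knows the tensor sizes and can filter out shardings that don't divide evenly across the mesh.

```python
def einsum_style_strategies(ndim, mesh_size):
    """Propose sharding each dim, plus replication; tensor sizes are unknown."""
    return ["Replicate"] + [f"Shard({d})" for d in range(ndim)]

def mm_style_strategies(shape, mesh_size):
    """Same proposals, but filtered using the known tensor sizes."""
    return ["Replicate"] + [
        f"Shard({d})" for d, size in enumerate(shape) if size % mesh_size == 0
    ]
```

For example, on a mesh of size 4, `einsum_style_strategies(2, 4)` proposes sharding both dims, while `mm_style_strategies((8, 6), 4)` drops `Shard(1)` because 6 isn't divisible by 4.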
Some things are starting to work, but we are not there yet
Before this, if we had a list of tensors, we wouldn't shard the tensors inside the list
This removes a long-standing hack to tell the solver that S(1) -> R is more expensive than S(0) -> R because of an additional data movement
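A hypothetical sketch of why S(1) -> R costs more than S(0) -> R: an allgather returns one contiguous buffer per rank, stacked along a new leading axis. For shards along dim 0, a plain reshape recovers the full tensor; for shards along dim 1, an extra transpose (i.e. a data copy) is needed to interleave the buffers. Simulated here with numpy in place of actual collectives.

```python
import numpy as np

full = np.arange(24).reshape(4, 6)
world = 2  # simulated world size

# S(0): shard along dim 0; "allgather" = stack the received buffers
shards0 = np.split(full, world, axis=0)
gathered0 = np.stack(shards0).reshape(4, 6)  # a reshape suffices

# S(1): shard along dim 1; the stacked buffers must be transposed first,
# which forces an additional data movement before the reshape
shards1 = np.split(full, world, axis=1)
gathered1 = np.stack(shards1).transpose(1, 0, 2).reshape(4, 6)

assert (gathered0 == full).all()
assert (gathered1 == full).all()
```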
Previously, if we had a tuple of tensors as an argument to a function, we wouldn't apply any sharding on it. This is split from #26, where I originally found this issue
* Support tuple of tensors in estimate_strategy_runtime_cost: previously, if we had a tuple of tensors as an argument to a function, we wouldn't apply any sharding on it. This is split from #26, where I originally found this issue
* Fix bad copy-paste
PyTorch currently decomposes any 3d-input nn.Linear (and matmul) into a sequence of view -> mm -> view operations. As a consequence, this breaks any sharding on both the batch and the sequence dimensions, because the flattening that happens doesn't preserve that sharding.
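A sketch of the decomposition versus the einsum form, using numpy as a stand-in for torch: the flattened view merges the batch and sequence dims into one, which is why sharding on either of them is lost, while the einsum form keeps both dims visible.

```python
import numpy as np

B, S, D, O = 2, 3, 4, 5
x = np.random.randn(B, S, D)
w = np.random.randn(O, D)  # nn.Linear-style weight (out_features, in_features)

# view -> mm -> view decomposition: batch and sequence dims are flattened away
decomposed = (x.reshape(B * S, D) @ w.T).reshape(B, S, O)

# single einsum: batch and sequence dims stay visible, so they remain shardable
fused = np.einsum("bsd,od->bso", x, w)

assert np.allclose(decomposed, fused)
```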
While we wait for PyTorch to avoid decomposing nn.Linear, we instead take the route of pattern-matching the nn.Linear-specific occurrences and replacing them with an einsum operator. We perform this pattern-matching replacement for both the forward and the backward pass.
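The idea of the replacement can be sketched as follows. This is a toy version (hypothetical, not the actual graph pass, which would operate on an FX graph rather than a flat list): scan a sequence of ops and collapse every view -> mm -> view triple into a single einsum op.

```python
def replace_view_mm_view(ops):
    """Collapse each view -> mm -> view triple into one einsum op."""
    out, i = [], 0
    while i < len(ops):
        if ops[i:i + 3] == ["view", "mm", "view"]:
            out.append("einsum")
            i += 3
        else:
            out.append(ops[i])
            i += 1
    return out
```

For instance, `["add", "view", "mm", "view", "relu"]` becomes `["add", "einsum", "relu"]`, while sequences that don't contain the full triple are left untouched.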
For now, the pass is disabled by default and can be enabled via a global flag. I'm leaving it disabled by default because it actually requires changing some other things, like improving the cost model as in #94, so I'm keeping the behavior the same for now while I experiment with those changes more easily.