
Conversation

@fmassa (Contributor) commented Jul 2, 2025

PyTorch currently decomposes any 3d-input nn.Linear (and matmul) into a sequence of view -> mm -> view operations.

As a consequence, any sharding on the batch or the sequence dimension is broken, because the flattening collapses those two dimensions into one and the sharding cannot be preserved across it.

While we wait for PyTorch to stop decomposing nn.Linear, we instead pattern-match the occurrences specific to nn.Linear and replace them with an einsum operator.

We perform this pattern-matching replacement in both the forward and the backward pass; a sketch of the rewrite follows below.
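For illustration, here is a minimal sketch (the function names and shape annotations are ours, not from the codebase) of the decomposed form described above and the einsum form it is rewritten into:

```python
import torch

# x: [batch, seq, in_features], weight: [out_features, in_features]
def linear_decomposed(x, weight):
    # what PyTorch's decomposition produces today: the first view flattens
    # batch and seq into one dim, destroying any sharding on either of them
    flat = x.view(-1, x.shape[-1])      # [batch * seq, in_features]
    out = torch.mm(flat, weight.t())    # [batch * seq, out_features]
    return out.view(*x.shape[:-1], weight.shape[0])

def linear_einsum(x, weight):
    # the replacement: batch and seq stay separate dims, so a sharding on
    # either of them can be propagated through the op
    return torch.einsum("bsi,oi->bso", x, weight)
```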

For now, the pass is disabled by default and can be enabled via a global flag. I'm leaving it disabled because enabling it requires changing some other things, like improving the cost model as in #94; keeping the default behavior unchanged lets me experiment with those other pieces more easily.
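To make the opt-in concrete, a hypothetical sketch; the module path and flag name below are placeholders of ours, not the repository's actual API:

```python
# hypothetical names -- consult the repository for the actual flag
from autoparallel import config as ap_config  # assumed module layout

ap_config.enable_einsum_replacement = True
```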

This tries to support CP-style sharding by overcoming a limitation of DTensor. It doesn't work yet, as _mm_strategy is failing.
@fmassa fmassa requested review from bdhirsh and wconstab July 2, 2025 12:30
@facebook-github-bot facebook-github-bot added the CLA Signed label Jul 2, 2025

```python
@register_opschema_rule(torch.ops.aten.matmul.default)
def matmul_rule(mesh, op_schema):
    # from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies
```
Contributor:
I would have thought to use the einsum strategies here. For my education, what is the difference between einsum and mm_like in this context? cc @XilunWu

Contributor Author:

I'll have to end up using the einsum strategies, because the mm strategies fail :-)

The difference, I believe, is just that the mm strategy already filters out invalid strategies, while the einsum strategies don't, since they only know the number of dimensions of the tensors, not their sizes.
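To illustrate, a sketch built on the private import quoted in the diff above; the helper name and equation are ours, and since the API is internal its signature may differ across PyTorch versions:

```python
from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies

def candidate_linear_strategies(mesh):
    # the generator only sees the einsum equation, i.e. how many dims each
    # operand has -- not the tensor sizes -- so unlike the mm strategy it
    # cannot drop placements that are invalid for the concrete shapes
    return gen_einsum_strategies("bsi,oi->bso", mesh)
```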

This removes a long-standing hack that told the solver that S(1) -> R is more expensive than S(0) -> R because of the additional data movement involved.
@fmassa fmassa changed the base branch from main to fmassa/compute_cost_in_comms August 21, 2025 09:33
…h/autoparallel into fmassa/replace_view_mm_view
fmassa added a commit that referenced this pull request Aug 21, 2025
Previously, if we had a tuple of tensors as an argument to a function, we wouldn't apply any sharding to it. This is split from #26, where I originally found this issue.
fmassa added a commit that referenced this pull request Aug 21, 2025
* Support tuple of tensors in estimate_strategy_runtime_cost

Previously, if we had a tuple of tensors as an argument to a function, we wouldn't apply any sharding to it. This is split from #26, where I originally found this issue.

* Fix bad copy-paste
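A minimal sketch of the idea behind that commit; the helper below is hypothetical, the actual change lives in estimate_strategy_runtime_cost:

```python
import torch

# hypothetical helper illustrating the fix: walk tuple/list arguments so
# their tensor elements are considered, instead of skipping the container
def iter_tensor_args(args):
    for arg in args:
        if isinstance(arg, (tuple, list)):
            for item in arg:
                if isinstance(item, torch.Tensor):
                    yield item
        elif isinstance(arg, torch.Tensor):
            yield arg
```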
@fmassa fmassa changed the title from "[WIP] Replace view -> mm -> view with matmul" to "Avoid nn.Linear decomposition by replacing view -> mm -> view with einsum" Aug 27, 2025
@fmassa fmassa marked this pull request as ready for review August 27, 2025 13:13
@fmassa fmassa changed the base branch from fmassa/compute_cost_in_comms to main August 27, 2025 13:20
@fmassa fmassa merged commit c680107 into main Aug 27, 2025
6 checks passed
@fmassa fmassa deleted the fmassa/replace_view_mm_view branch August 27, 2025 13:44
wconstab added a commit that referenced this pull request Aug 27, 2025