Merged
29 commits
d65d06e
[WIP] Replace view -> mm -> view with matmul
fmassa Jul 2, 2025
d7398c0
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 3, 2025
e7f2003
Fix matmul propagation rule
fmassa Jul 3, 2025
7cc87e1
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 30, 2025
eadc2c3
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa Aug 12, 2025
d7dabee
Move function to graph_utils.py
fmassa Aug 12, 2025
99abc4e
Pull improvements from https://github.com/meta-pytorch/autoparallel/p…
fmassa Aug 12, 2025
48ae195
Fix equation for einsum
fmassa Aug 12, 2025
2a2a4d8
Cleanup code now that PyTorch has fixed _gen_einsum_strategies
fmassa Aug 12, 2025
b5a5098
Generalize to more than 3d
fmassa Aug 12, 2025
840690b
Generalize backward pass as well and make everything call into einsum
fmassa Aug 12, 2025
1d443b8
Add note about future work
fmassa Aug 12, 2025
d6c8ae0
Add einsum flops and generalize creation of sharded tensors
fmassa Aug 12, 2025
12155e2
Disable erroneous sdpa rule from backward
fmassa Aug 12, 2025
25fdd8e
Account for compute cost in collectives as well
fmassa Aug 13, 2025
d1281a4
Account for compute cost in collectives as well
fmassa Aug 13, 2025
299f184
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa Aug 14, 2025
a56c784
Support getitem as well
fmassa Aug 14, 2025
851cf00
Improve comments and suppose 80% efficiency
fmassa Aug 16, 2025
8396a09
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa Aug 20, 2025
5f4f730
Suppose 70% efficiency for comms
fmassa Aug 20, 2025
2e46457
Merge branch 'fmassa/compute_cost_in_comms' of github.com:meta-pytorc…
fmassa Aug 21, 2025
a8f435c
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa Aug 21, 2025
b4ae76d
Merge branch 'fmassa/compute_cost_in_comms' of github.com:meta-pytorc…
fmassa Aug 21, 2025
e025188
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa Aug 22, 2025
10219c9
Merge branch 'fmassa/compute_cost_in_comms' of github.com:meta-pytorc…
fmassa Aug 22, 2025
e3c5e9f
Add comment and set it to false by default
fmassa Aug 27, 2025
4b43944
Revert changes from another PR
fmassa Aug 27, 2025
5d434fd
Add spaces back
fmassa Aug 27, 2025
5 changes: 5 additions & 0 deletions autoparallel/api.py
@@ -29,6 +29,7 @@
from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
from .graph_utils import (
_add_alias,
_replace_view_mm_view_with_einsum,
assert_has_no_collectives,
cleanup_graph,
update_joint_with_descriptors,
@@ -37,6 +38,8 @@
from .optimize_sharding import ShardingOptimizer
from .utils import _get_device_from_mesh

_APPLY_VIEW_MM_VIEW_PATTERN = False


def try_convert_fake_to_real(tensors):
out = {}
@@ -230,6 +233,8 @@ def build_model_graph(self):
assert_has_no_collectives(gm)

cleanup_graph(gm)
if _APPLY_VIEW_MM_VIEW_PATTERN:
_replace_view_mm_view_with_einsum(gm)
# now add alias nodes to the graph to
# give more room for optimizations
_add_alias(gm)
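Not part of the diff: `_APPLY_VIEW_MM_VIEW_PATTERN` is a module-level constant that defaults to `False`, so the rewrite above is skipped unless the flag is flipped. A minimal sketch of one way to opt in (an assumed usage pattern, not documented in the PR):

```python
# Hedged sketch: flip the module-level flag before the graph is built so that
# build_model_graph() takes the _replace_view_mm_view_with_einsum branch.
import autoparallel.api as api

api._APPLY_VIEW_MM_VIEW_PATTERN = True
```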
29 changes: 28 additions & 1 deletion autoparallel/compute_estimation.py
@@ -8,7 +8,34 @@

import torch
from torch.utils._pytree import tree_flatten, tree_map_only
from torch.utils.flop_counter import FlopCounterMode
from torch.utils.flop_counter import FlopCounterMode, register_flop_formula


@register_flop_formula(torch.ops.aten.einsum, get_raw=True)
def einsum_flop(equation, tensors, out=None, **kwargs) -> int:
# from torch.distributed.tensor._ops._einsum_strategy import EinsumDims
assert len(tensors) == 2
a_shape, b_shape = [x.shape for x in tensors]

# parse einop equation and extract dims
# TODO: generalize
# input_dims, output_dim = EinsumDims.parse_equation(equation)
# edims = EinsumDims.parse_dims(input_dims, output_dim)

if len(a_shape) == 3 and len(b_shape) == 3:
b, m, k = a_shape
b1, n, k2 = b_shape
assert b == b1
assert m == n
flop = (b * m) * k * k2 * 2
elif len(a_shape) == 3 and len(b_shape) == 2:
b, m, k = a_shape
k2, n = b_shape
assert k == k2
flop = b * m * n * k * 2
else:
raise NotImplementedError(f"Unsupported einsum shapes: {a_shape} {b_shape}")
return flop


@dataclass
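As a quick sanity check (not part of the diff), the registered formula can be called directly, since it only inspects tensor shapes; the shapes below are illustrative:

```python
# Hedged sketch: exercise einsum_flop by hand on the 3d x 2d case.
import torch

from autoparallel.compute_estimation import einsum_flop

a = torch.empty(4, 8, 16)  # (b, m, k)
w = torch.empty(16, 32)    # (k, n)

# flop = b * m * n * k * 2 = 4 * 8 * 32 * 16 * 2 = 32768
assert einsum_flop("abk,kn->abn", (a, w)) == 32768
```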
78 changes: 78 additions & 0 deletions autoparallel/graph_utils.py
@@ -153,3 +153,81 @@ def assert_has_no_collectives(gm: torch.fx.GraphModule):
f"autoparallel.local_map_hop.apply_local_map, see "
"examples/example_local_map.py for more information."
)


# NOTE: [nn.Linear decomposition]
# PyTorch currently decomposes any 3d-input nn.Linear (and matmul) into a
# sequence of view -> mm -> view operations.
# As a consequence, any sharding on the batch and sequence dimensions is
# broken, because the flattening performed by the views cannot preserve it.
# While we wait for PyTorch to stop decomposing nn.Linear, we instead
# pattern-match the nn.Linear-specific occurrences and replace them with an
# einsum operator.
# We perform this pattern-matching replacement for both the forward and the
# backward pass (see the sketch after this file's diff for an illustration).
# TODO: use graph_patterns to simplify writing this
def _replace_view_mm_view_with_einsum(gm):
mm_nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.aten.mm.default)
for node in mm_nodes:
first_input, second_input = node.all_input_nodes
if first_input.target == torch.ops.aten.view.default:
view_input = first_input.all_input_nodes[0]
users = list(node.users)
if (
len(users) == 1
and users[0].target == torch.ops.aten.view.default
and view_input.meta["val"].shape[:-1] == users[0].meta["val"].shape[:-1]
and second_input.meta["val"].ndim == 2
):
print(
f"Found matmul node {node}, {view_input.meta['val'].shape, second_input.meta['val'].shape}"
)
ndim = view_input.meta["val"].ndim
assert 1 < ndim <= 10, "Only support up to 10D for now"

# generate the leading dimensions as a, b, c, etc
dims = "".join([chr(97 + i) for i in range(ndim - 1)])
mm_equation = f"{dims}k,kn->{dims}n"
with gm.graph.inserting_before(node):
new_node = gm.graph.call_function(
torch.ops.aten.einsum.default,
args=(mm_equation, [view_input, second_input]),
)
new_node.meta.update(users[0].meta)
users[0].replace_all_uses_with(new_node)

elif second_input.target == torch.ops.aten.view.default:
if first_input.target != torch.ops.aten.permute.default:
continue
if first_input.all_input_nodes[0].target != torch.ops.aten.view.default:
continue
orig_first = first_input.all_input_nodes[0].all_input_nodes[0]
orig_second = second_input.all_input_nodes[0]
users = list(node.users)
if (
len(users) == 1
and users[0].target == torch.ops.aten.permute.default
and orig_first.meta["val"].shape[:-1]
== orig_second.meta["val"].shape[:-1]
and node.meta["val"].ndim == 2
):
print(
f"Found matmul node {node} {orig_first.meta['val'].shape, orig_second.meta['val'].shape}"
)

ndim = orig_first.meta["val"].ndim
assert 1 < ndim <= 10, "Only support up to 10D for now"

# generate the leading dimensions as a, b, c, etc
dims = "".join([chr(97 + i) for i in range(ndim - 1)])
mm_equation = f"{dims}n,{dims}k->kn"
with gm.graph.inserting_before(node):
new_node = gm.graph.call_function(
torch.ops.aten.einsum.default,
args=(mm_equation, [orig_first, orig_second]),
)
new_node.meta.update(users[0].meta)
users[0].replace_all_uses_with(new_node)
gm.graph.eliminate_dead_code()
gm.recompile()
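To make the note above concrete, here is a minimal standalone sketch (not part of the PR; the shapes and the explicit `w.t()` are illustrative assumptions) of the forward-pass equivalence the replacement relies on:

```python
# Hedged sketch: the view -> mm -> view decomposition of a 3d nn.Linear input
# and the einsum that _replace_view_mm_view_with_einsum rewrites it into.
import torch

x = torch.randn(2, 5, 16)  # (batch, seq, in_features)
w = torch.randn(32, 16)    # nn.Linear weight: (out_features, in_features)

# What the decomposition produces in the graph.
out_decomposed = x.view(-1, 16).mm(w.t()).view(2, 5, 32)

# The replacement: ndim == 3, so the leading dims are "ab" and the
# equation becomes "abk,kn->abn".
out_einsum = torch.einsum("abk,kn->abn", x, w.t())

torch.testing.assert_close(out_decomposed, out_einsum)
```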
22 changes: 22 additions & 0 deletions autoparallel/propagation_rules.py
@@ -761,3 +761,25 @@ def expand_rule(mesh, op_schema_):
for remov in to_remove:
ss.redistribute_cost[0].insert(remov, math.inf)
return out_strat


@register_opschema_rule(torch.ops.aten.einsum.default)
def einsum_rule(mesh, op_schema):
from torch.distributed.tensor._op_schema import TupleStrategy
from torch.distributed.tensor._ops._matrix_ops import _mm_like_strategy

mm_equation, mat_strategy = op_schema.args_schema
assert isinstance(mm_equation, str)
assert isinstance(mat_strategy, TupleStrategy)

assert len(mat_strategy.children) == 2, "Only two args to einsum supported for now"

self_strategy, mat2_strategy = mat_strategy.children

# dispatch to mm_like_strategy
new_op_schema = OpSchema(
torch.ops.aten.einsum.default,
args_schema=(self_strategy, mat2_strategy),
kwargs_schema={},
)
return _mm_like_strategy(mm_equation, mesh, new_op_schema)