diff --git a/autoparallel/api.py b/autoparallel/api.py
index c664aa49..39d239d6 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -29,6 +29,7 @@
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
 from .graph_utils import (
     _add_alias,
+    _replace_view_mm_view_with_einsum,
     assert_has_no_collectives,
     cleanup_graph,
     update_joint_with_descriptors,
@@ -37,6 +38,8 @@
 from .optimize_sharding import ShardingOptimizer
 from .utils import _get_device_from_mesh
 
+_APPLY_VIEW_MM_VIEW_PATTERN = False
+
 
 def try_convert_fake_to_real(tensors):
     out = {}
@@ -230,6 +233,8 @@ def build_model_graph(self):
         assert_has_no_collectives(gm)
 
         cleanup_graph(gm)
+        if _APPLY_VIEW_MM_VIEW_PATTERN:
+            _replace_view_mm_view_with_einsum(gm)
         # now add aliases nodes to the graph to
         # give more room for optimizations
         _add_alias(gm)
diff --git a/autoparallel/compute_estimation.py b/autoparallel/compute_estimation.py
index 3d8678cb..3012a280 100644
--- a/autoparallel/compute_estimation.py
+++ b/autoparallel/compute_estimation.py
@@ -8,7 +8,34 @@
 import torch
 from torch.utils._pytree import tree_flatten, tree_map_only
-from torch.utils.flop_counter import FlopCounterMode
+from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
+
+
+@register_flop_formula(torch.ops.aten.einsum, get_raw=True)
+def einsum_flop(equation, tensors, out=None, **kwargs) -> int:
+    # from torch.distributed.tensor._ops._einsum_strategy import EinsumDims
+    assert len(tensors) == 2
+    a_shape, b_shape = [x.shape for x in tensors]
+
+    # parse einop equation and extract dims
+    # TODO: generalize
+    # input_dims, output_dim = EinsumDims.parse_equation(equation)
+    # edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+    if len(a_shape) == 3 and len(b_shape) == 3:
+        b, m, k = a_shape
+        b1, n, k2 = b_shape
+        assert b == b1
+        assert m == n
+        flop = (b * m) * k * k2 * 2
+    elif len(a_shape) == 3 and len(b_shape) == 2:
+        b, m, k = a_shape
+        k2, n = b_shape
+        assert k == k2
+        flop = b * m * n * k * 2
+    else:
+        raise NotImplementedError(f"Unsupported einsum shapes: {a_shape} {b_shape}")
+    return flop
 
 
 @dataclass
diff --git a/autoparallel/graph_utils.py b/autoparallel/graph_utils.py
index 03383f69..e5f8d663 100644
--- a/autoparallel/graph_utils.py
+++ b/autoparallel/graph_utils.py
@@ -153,3 +153,81 @@ def assert_has_no_collectives(gm: torch.fx.GraphModule):
            f"autoparallel.local_map_hop.apply_local_map, see "
            "examples/example_local_map.py for more information."
        )
+
+
+# NOTE: [nn.Linear decomposition]
+# PyTorch currently decomposes any 3d-input nn.Linear (and matmul) into a
+# sequence of view -> mm -> view operations.
+# As a consequence, any sharding on the batch or the sequence dimension is
+# broken, because the flattening of those dimensions doesn't allow the
+# sharding to be preserved.
+# While we wait for PyTorch to avoid decomposing nn.Linear, we instead take
+# the route of pattern-matching the nn.Linear-specific occurrences and
+# replacing them with an einsum operator.
+# We perform this pattern-matching replacement for both the forward and the
+# backward pass.
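+#
+# As an illustration (the shapes here are only an example): a 3d input of
+# shape [B, S, D] multiplied by a 2d matrix of shape [D, F] is decomposed as
+#   view([B*S, D]) -> mm([B*S, D], [D, F]) -> view([B, S, F]),
+# so a sharding on B or S cannot be expressed on the flattened [B*S, D]
+# operand. Rewriting the pattern as einsum("abk,kn->abn", x, w) keeps the
+# leading dimensions separate, so they can remain sharded.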
+# TODO: use graph_patterns to simplify writing this
+def _replace_view_mm_view_with_einsum(gm):
+    mm_nodes = gm.graph.find_nodes(
+        op="call_function", target=torch.ops.aten.mm.default
+    )
+    for node in mm_nodes:
+        first_input, second_input = node.all_input_nodes
+        if first_input.target == torch.ops.aten.view.default:
+            # forward pattern: view -> mm -> view, with a 2d second operand
+            view_input = first_input.all_input_nodes[0]
+            users = list(node.users)
+            if (
+                len(users) == 1
+                and users[0].target == torch.ops.aten.view.default
+                and view_input.meta["val"].shape[:-1] == users[0].meta["val"].shape[:-1]
+                and second_input.meta["val"].ndim == 2
+            ):
+                print(
+                    f"Found matmul node {node}, {view_input.meta['val'].shape, second_input.meta['val'].shape}"
+                )
+                ndim = view_input.meta["val"].ndim
+                assert 1 < ndim <= 10, "Only support up to 10D for now"
+
+                # generate the leading dimensions as a, b, c, etc.
+                dims = "".join([chr(97 + i) for i in range(ndim - 1)])
+                mm_equation = f"{dims}k,kn->{dims}n"
+                with gm.graph.inserting_before(node):
+                    new_node = gm.graph.call_function(
+                        torch.ops.aten.einsum.default,
+                        args=(mm_equation, [view_input, second_input]),
+                    )
+                new_node.meta.update(users[0].meta)
+                users[0].replace_all_uses_with(new_node)
+
+        elif second_input.target == torch.ops.aten.view.default:
+            # backward pattern: mm(permute(view(a)), view(b)) followed by a
+            # permute, which contracts over the leading dimensions
+            if first_input.target != torch.ops.aten.permute.default:
+                continue
+            if first_input.all_input_nodes[0].target != torch.ops.aten.view.default:
+                continue
+            orig_first = first_input.all_input_nodes[0].all_input_nodes[0]
+            orig_second = second_input.all_input_nodes[0]
+            users = list(node.users)
+            if (
+                len(users) == 1
+                and users[0].target == torch.ops.aten.permute.default
+                and orig_first.meta["val"].shape[:-1]
+                == orig_second.meta["val"].shape[:-1]
+                and node.meta["val"].ndim == 2
+            ):
+                print(
+                    f"Found matmul node {node} {orig_first.meta['val'].shape, orig_second.meta['val'].shape}"
+                )
+
+                ndim = orig_first.meta["val"].ndim
+                assert 1 < ndim <= 10, "Only support up to 10D for now"
+
+                # generate the leading dimensions as a, b, c, etc.
+                dims = "".join([chr(97 + i) for i in range(ndim - 1)])
+                mm_equation = f"{dims}n,{dims}k->kn"
+                with gm.graph.inserting_before(node):
+                    new_node = gm.graph.call_function(
+                        torch.ops.aten.einsum.default,
+                        args=(mm_equation, [orig_first, orig_second]),
+                    )
+                new_node.meta.update(users[0].meta)
+                users[0].replace_all_uses_with(new_node)
+    gm.graph.eliminate_dead_code()
+    gm.recompile()
diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py
index cc5602ab..68cb42a9 100644
--- a/autoparallel/propagation_rules.py
+++ b/autoparallel/propagation_rules.py
@@ -761,3 +761,25 @@ def expand_rule(mesh, op_schema_):
         for remov in to_remove:
             ss.redistribute_cost[0].insert(remov, math.inf)
     return out_strat
+
+
+@register_opschema_rule(torch.ops.aten.einsum.default)
+def einsum_rule(mesh, op_schema):
+    from torch.distributed.tensor._op_schema import TupleStrategy
+    from torch.distributed.tensor._ops._matrix_ops import _mm_like_strategy
+
+    mm_equation, mat_strategy = op_schema.args_schema
+    assert isinstance(mm_equation, str)
+    assert isinstance(mat_strategy, TupleStrategy)
+
+    assert len(mat_strategy.children) == 2, "Only two args to einsum supported for now"
+
+    self_strategy, mat2_strategy = mat_strategy.children
+
+    # dispatch to mm_like_strategy
+    new_op_schema = OpSchema(
+        torch.ops.aten.einsum.default,
+        args_schema=(self_strategy, mat2_strategy),
+        kwargs_schema={},
+    )
+    return _mm_like_strategy(mm_equation, mesh, new_op_schema)
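As a quick sanity check of the rewrites above, the two einsum equations can be compared numerically against the decomposed view -> mm -> view form they replace. The snippet below is a standalone sketch (the shapes and the B, S, D, F names are arbitrary examples, not taken from the patch):

import torch

# Arbitrary example shapes: batch, sequence, in-features, out-features.
B, S, D, F = 2, 3, 4, 5
x = torch.randn(B, S, D)
w = torch.randn(D, F)
go = torch.randn(B, S, F)  # stand-in for an upstream gradient

# Forward pattern: view -> mm -> view equals a single einsum that keeps the
# leading (batch/sequence) dimensions separate.
out_decomposed = x.view(-1, D).mm(w).view(B, S, F)
out_einsum = torch.einsum("abk,kn->abn", x, w)
torch.testing.assert_close(out_decomposed, out_einsum)

# Backward pattern: mm(permute(view(go)), view(x)) followed by a permute
# equals einsum("abn,abk->kn", go, x), which contracts the leading dims.
gw_decomposed = go.view(-1, F).t().mm(x.view(-1, D)).t()
gw_einsum = torch.einsum("abn,abk->kn", go, x)
torch.testing.assert_close(gw_decomposed, gw_einsum)

For the 3d x 2d forward case, einsum_flop above would count 2 * B * S * D * F flops, which matches the usual matmul flop count.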