
Commit ed3312a

Miscellaneous improvements for MLA (#7)
With those improvements, AMAIA's implementation of MLA seems to work
1 parent b51cd9a commit ed3312a

3 files changed: +29, -9 lines

autoparallel/compute_estimation.py (+16, -5)

@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from torch.utils._pytree import tree_map_only
+from torch.utils._pytree import tree_flatten, tree_map_only
 from torch.utils.flop_counter import FlopCounterMode
 
 
@@ -59,12 +59,20 @@ def estimate_strategy_runtime_cost(node, strategy):
 
     args = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.args)
     kwargs = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.kwargs)
-    fake_mode = next(
+
+    fake_modes = [
         arg.fake_mode
-        for arg in args
+        for arg in tree_flatten(args)[0]
         if isinstance(arg, torch._subclasses.fake_tensor.FakeTensor)
-    )
-    assert len(kwargs) == 0
+    ]
+    if len(fake_modes) == 0:
+        return 0
+
+    assert all(fm == fake_modes[0] for fm in fake_modes)
+    fake_mode = fake_modes[0]
+    if len(kwargs) > 0:
+        for k, v in kwargs.items():
+            assert not isinstance(v, torch.Tensor), f"{node} {v}"
     args_shapes = tuple(_get_sharded_shape(spec) for spec in strategy.input_specs)
 
     counter = 0
@@ -87,6 +95,9 @@ def estimate_strategy_runtime_cost(node, strategy):
     # TODO: fix this
     dtype = strategy.input_specs[0].tensor_meta.dtype
 
+    # TODO: better handle this case
+    if dtype.is_complex:
+        return 0
     # TODO: use PyTorch's version once it's giving correct results
     gpu_flops = _get_device_tflops(dtype) * 10**12
 
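Not part of the commit, but a quick illustration of why switching from `for arg in args` to `tree_flatten(args)[0]` matters: `node.args` can nest tensors inside lists or tuples, and only a pytree flatten sees those nested leaves. A minimal sketch using ordinary tensors in place of FakeTensors:

import torch
from torch.utils._pytree import tree_flatten

# args as they might look for an op that takes a list of tensors:
# some leaves sit one level down inside a list.
args = (torch.randn(2, 2), [torch.randn(3), torch.randn(4)], 1.0)

# Iterating the top level misses the nested tensors ...
top_level = [a for a in args if isinstance(a, torch.Tensor)]
print(len(top_level))  # 1

# ... while tree_flatten yields every leaf, however deeply nested.
leaves, _ = tree_flatten(args)
nested_aware = [a for a in leaves if isinstance(a, torch.Tensor)]
print(len(nested_aware))  # 3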

autoparallel/propagation_rules.py (+8, -1)

@@ -678,7 +678,14 @@ def expand_rule(mesh, op_schema_):
         for i, (s1, s2) in enumerate(zip(orig_shape, dest_shape))
         if s1 == 1 and s2 != s1
     ]
-    assert len(expand_dim) == 1
+    if len(expand_dim) != 1:
+        assert len(expand_dim) == 0
+        return torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
+            op
+        ](
+            op_schema
+        )
+    assert len(expand_dim) == 1, f"{expand_dim}"
     expand_dim = expand_dim[0]
     to_remove = []
     for i, ss in enumerate(input_strat.strategies):
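For context rather than as part of the commit: `expand_dim` collects the dimensions that `aten.expand` actually broadcasts (a size of 1 growing to something larger), and the new branch covers the case where nothing is broadcast by deferring to DTensor's own registered strategy for the op. A small worked example of the detection logic, assuming the comprehension collects the index `i` (the line just above the hunk is not shown):

# One broadcast dim: (1, 4) expanded to (8, 4) -- the case the rule already handled.
orig_shape, dest_shape = (1, 4), (8, 4)
expand_dim = [
    i
    for i, (s1, s2) in enumerate(zip(orig_shape, dest_shape))
    if s1 == 1 and s2 != s1
]
print(expand_dim)  # [0]

# No broadcast at all: expand to the same shape -- the new fallback case.
orig_shape, dest_shape = (2, 4), (2, 4)
expand_dim = [
    i
    for i, (s1, s2) in enumerate(zip(orig_shape, dest_shape))
    if s1 == 1 and s2 != s1
]
print(expand_dim)  # []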

autoparallel/utils.py (+5, -3)

@@ -43,12 +43,14 @@ def propagate_tensor_meta(op, user_args, out_strat):
         else:
             assert tm is None
         if strat.input_specs is None:
-            assert op in {
+            supported_ops = {
                 torch.ops.prims.convert_element_type.default,
+                torch.ops.aten.clone.default,
                 torch.ops.aten.slice.Tensor,
-            }, (
+            }
+            assert op in supported_ops, (
                 f"{op} strategy doesn't have input_specs, only harcoded "
-                "prims.convert_element_type.default and aten.slice.Tensor for now"
+                "{supported_ops} for now"
             )
             strat.input_specs = (strat.output_specs,)
         assert strat.redistribute_cost is None
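Again for illustration only: for these pass-through ops the strategy's missing input spec is filled in by reusing its output spec. A toy sketch with a hypothetical `ToyStrategy` stand-in and a placeholder placement string, not the real strategy objects:

import torch

# Hypothetical stand-in for an op strategy whose input_specs were not populated.
class ToyStrategy:
    def __init__(self, output_specs):
        self.output_specs = output_specs
        self.input_specs = None

supported_ops = {
    torch.ops.prims.convert_element_type.default,
    torch.ops.aten.clone.default,
    torch.ops.aten.slice.Tensor,
}

strat = ToyStrategy(output_specs="Shard(0)")  # placeholder placement spec
op = torch.ops.aten.clone.default
if strat.input_specs is None:
    # Only "pass-through" ops are handled: the input placement is assumed
    # to mirror the output placement.
    assert op in supported_ops, f"{op} strategy doesn't have input_specs"
    strat.input_specs = (strat.output_specs,)
print(strat.input_specs)  # ('Shard(0)',)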
