autoparallel/compute_estimation.py (21 changes: 16 additions & 5 deletions)

@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from torch.utils._pytree import tree_map_only
+from torch.utils._pytree import tree_flatten, tree_map_only
 from torch.utils.flop_counter import FlopCounterMode


@@ -59,12 +59,20 @@ def estimate_strategy_runtime_cost(node, strategy):

     args = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.args)
     kwargs = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], node.kwargs)
-    fake_mode = next(
+
+    fake_modes = [
         arg.fake_mode
-        for arg in args
+        for arg in tree_flatten(args)[0]
         if isinstance(arg, torch._subclasses.fake_tensor.FakeTensor)
-    )
-    assert len(kwargs) == 0
+    ]
+    if len(fake_modes) == 0:
+        return 0
+
+    assert all(fm == fake_modes[0] for fm in fake_modes)
+    fake_mode = fake_modes[0]
+    if len(kwargs) > 0:
+        for k, v in kwargs.items():
+            assert not isinstance(v, torch.Tensor), f"{node} {v}"
     args_shapes = tuple(_get_sharded_shape(spec) for spec in strategy.input_specs)

     counter = 0

@@ -87,6 +95,9 @@ def estimate_strategy_runtime_cost(node, strategy):
     # TODO: fix this
     dtype = strategy.input_specs[0].tensor_meta.dtype

+    # TODO: better handle this case
+    if dtype.is_complex:
+        return 0
     # TODO: use PyTorch's version once it's giving correct results
     gpu_flops = _get_device_tflops(dtype) * 10**12

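For context, the switch from next(... for arg in args ...) to tree_flatten means FakeTensors nested inside tuples, lists, or dicts of node.args are now found as well. The following is a minimal sketch of that lookup, not the PR's code verbatim; find_fake_mode and the sample nested_args are hypothetical, assuming args can be any pytree mixing FakeTensors with plain values.

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._pytree import tree_flatten


def find_fake_mode(args):
    # tree_flatten returns (leaves, treespec); only the leaves matter here.
    leaves, _ = tree_flatten(args)
    fake_modes = [a.fake_mode for a in leaves if isinstance(a, FakeTensor)]
    if not fake_modes:
        # The PR returns an estimated cost of 0 when no FakeTensor is present.
        return None
    # All FakeTensors reaching one node are expected to share a single mode.
    assert all(fm is fake_modes[0] for fm in fake_modes)
    return fake_modes[0]


with FakeTensorMode() as mode:
    x = torch.empty(2, 3)

# FakeTensors below the top level are what a plain `for arg in args`
# iteration over node.args would have skipped.
nested_args = ((x, [x]), {"alpha": 2.0})
assert find_fake_mode(nested_args) is mode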
autoparallel/propagation_rules.py (9 changes: 8 additions & 1 deletion)

@@ -678,7 +678,14 @@ def expand_rule(mesh, op_schema_):
         for i, (s1, s2) in enumerate(zip(orig_shape, dest_shape))
         if s1 == 1 and s2 != s1
     ]
-    assert len(expand_dim) == 1
+    if len(expand_dim) != 1:
+        assert len(expand_dim) == 0
+        return torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
+            op
+        ](
+            op_schema
+        )
+    assert len(expand_dim) == 1, f"{expand_dim}"
     expand_dim = expand_dim[0]
     to_remove = []
     for i, ss in enumerate(input_strat.strategies):

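The new branch handles the case where expand is effectively a no-op (no dimension actually goes from 1 to something larger) by deferring to the strategy function DTensor itself has registered for the op. A minimal sketch of that lookup pattern, assuming op is the OpOverload and op_schema is the OpSchema that expand_rule already has in scope; the helper name below is hypothetical.

from torch.distributed.tensor import DTensor


def fallback_to_builtin_strategy(op, op_schema):
    # DTensor keeps a registry mapping each op to its registered sharding
    # strategy function; when the custom expand handling does not apply,
    # reuse the stock rule instead.
    registry = DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs
    return registry[op](op_schema)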
autoparallel/utils.py (8 changes: 5 additions & 3 deletions)

@@ -43,12 +43,14 @@ def propagate_tensor_meta(op, user_args, out_strat):
         else:
             assert tm is None
         if strat.input_specs is None:
-            assert op in {
+            supported_ops = {
                 torch.ops.prims.convert_element_type.default,
                 torch.ops.aten.clone.default,
                 torch.ops.aten.slice.Tensor,
-            }, (
+            }
+            assert op in supported_ops, (
                 f"{op} strategy doesn't have input_specs, only harcoded "
-                "prims.convert_element_type.default and aten.slice.Tensor for now"
+                "{supported_ops} for now"
             )
             strat.input_specs = (strat.output_specs,)
         assert strat.redistribute_cost is None

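For reference, a small self-contained sketch of the refactored check; check_input_specs is a hypothetical stand-in for the guarded block inside propagate_tensor_meta, and strat is assumed to carry input_specs and output_specs attributes as in DTensor's strategy objects. Note that the second string literal in the diff has no f prefix, so {supported_ops} would appear literally in the assertion message; the sketch below uses an f-string.

import torch

SUPPORTED_OPS = {
    torch.ops.prims.convert_element_type.default,
    torch.ops.aten.clone.default,
    torch.ops.aten.slice.Tensor,
}


def check_input_specs(op, strat):
    # Hypothetical helper mirroring the refactor: name the set once so the
    # assertion message can list every supported op.
    if strat.input_specs is None:
        assert op in SUPPORTED_OPS, (
            f"{op} strategy doesn't have input_specs, only hardcoded "
            f"{SUPPORTED_OPS} for now"
        )
        strat.input_specs = (strat.output_specs,)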