4 changes: 3 additions & 1 deletion autoparallel/apply_sharding.py
@@ -15,6 +15,8 @@
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._pytree import tree_flatten, tree_map_only

from .propagation_rules import TENSOR_FACTORY_OPS


def my_redistribute_local_tensor(arg, curr_spec, tgt_spec):
# if curr_spec.placements == (Shard(0), Shard(0)) and tgt_spec.placements == (
@@ -129,7 +131,7 @@ def call_function(self, target, args, kwargs):
new_args = self.redistribute_args(args)

# apply sharding to constructor functions as well
if target == torch.ops.aten.full.default:
if target in TENSOR_FACTORY_OPS:
val = list(new_args[0])
spec = self.sharding_placement[node].output_specs
for mesh_size, placement in zip(spec.mesh.shape, spec.placements):
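To illustrate the new branch, here is a minimal standalone sketch (not the ApplySharding code itself) of how a factory op's global size becomes a per-rank local size, assuming the truncated loop body divides each sharded dim by the corresponding mesh size:

```python
from torch.distributed.tensor.placement_types import Replicate, Shard


def global_to_local_size(size, mesh_shape, placements):
    # Shrink the global factory size along every sharded dim; replicated
    # mesh dims leave the size untouched.
    local = list(size)
    for mesh_size, placement in zip(mesh_shape, placements):
        if isinstance(placement, Shard):
            local[placement.dim] //= mesh_size
    return local


# e.g. a (16, 32) torch.ones placed as (Shard(0), Replicate()) on a 4x2 mesh
print(global_to_local_size((16, 32), (4, 2), (Shard(0), Replicate())))  # [4, 32]
```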
114 changes: 94 additions & 20 deletions autoparallel/propagation_rules.py
@@ -33,7 +33,10 @@
propagate_shape_and_sharding,
register_op_strategy_map,
)
from torch.distributed.tensor._ops.utils import generate_redistribute_costs
from torch.distributed.tensor._ops.utils import (
generate_redistribute_costs,
is_tensor_shardable,
)
from torch.distributed.tensor.placement_types import Replicate, Shard

# TODO: move this to PyTorch
@@ -60,11 +63,15 @@ def wrapper(impl):
_op_partial_rules = {}


def register_opschema_rule(op):
def register_opschema_rule(ops):
global _op_partial_rules

def wrapper(impl):
_op_partial_rules[op] = impl
if isinstance(ops, list):
for op in ops:
_op_partial_rules[op] = impl
else:
_op_partial_rules[ops] = impl
return impl

return wrapper
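As a hypothetical usage sketch of the updated decorator (the list form is what `TENSOR_FACTORY_OPS` relies on below; `my_rule` is an illustrative name, not code from this PR), registering one rule for several ops now works like this:

```python
import torch

from autoparallel.propagation_rules import _op_partial_rules, register_opschema_rule


@register_opschema_rule([torch.ops.aten.zeros.default, torch.ops.aten.ones.default])
def my_rule(mesh, op_schema):
    ...


# The same rule object ends up registered once per op in the list.
assert _op_partial_rules[torch.ops.aten.zeros.default] is my_rule
assert _op_partial_rules[torch.ops.aten.ones.default] is my_rule
```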
@@ -97,7 +104,12 @@ def remove_invalid_configs(out_strat, mesh):
output_specs = strategy.output_specs
if isinstance(output_specs, DTensorSpec):
output_specs = [output_specs]
specs = list(strategy.input_specs) + list(output_specs)
if strategy.input_specs is not None:
specs = list(strategy.input_specs) + list(output_specs)
else:
# special case for ops like full, empty, which have no inputs. See further comments by `TENSOR_FACTORY_OPS`
specs = list(output_specs)

for spec in specs:
if spec is None:
continue
@@ -335,22 +347,84 @@ def randperm_rule(mesh, specs):
return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])


@register_rule(torch.ops.aten.full.default)
def full_rule(mesh, specs):
raise NotImplementedError("Needs hardening, only tested on a few cases")
shape = specs[0]
# TODO: get the dtype
tensor_meta = _gen_tensor_meta(shape)
# TODO: I'm hard-coding this here, I'll probably need to do something else about this
placement = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
# placement = (Replicate(),) * mesh.ndim
input_placement = (Replicate(),) * mesh.ndim
spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
input_spec = DTensorSpec(mesh, input_placement, tensor_meta=tensor_meta)
# return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])
return OpStrategy(
[OpSpec(spec, input_specs=[input_spec], redistribute_cost=[[0.0]])]
)
# We do a few special things for factory ops
# - use the factory rule below
# - fake that they have input schemas so the solver doesn't freak out
# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
TENSOR_FACTORY_OPS = [
torch.ops.aten.zeros.default,
torch.ops.aten.ones.default,
torch.ops.aten.full.default,
torch.ops.aten.empty.memory_format,
torch.ops.aten.rand.default,
torch.ops.aten.randn.default,
]


@register_opschema_rule(TENSOR_FACTORY_OPS)
def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
Contributor:
Maybe we could replace / unify this implementation with the implementation in _create_all_options https://github.com/pytorch-labs/autoparallel/blob/b53ad103b2054177db1c0ac50d0b0021a5b8bb57/autoparallel/propagation_rules.py#L119-L159 ?

Can be left for the future, just bringing this up as I think they are quite similar

Contributor Author:
I wasn't aware of these; agreed, we can refactor to remove the duplication. I can do that in a follow-up PR. A few questions:

  1. What is `_create_all_options_no_nested_sharding` intended for? It is currently unused, so I might delete it unless you have a use case in mind.
  2. `_create_all_options` is pretty close to `factory_rule`. I can probably make `factory_rule` delegate to `_create_all_options`.
  3. `_create_all_options` is a bit vague, and the TODO suggests adding Partial, but factories would not want Partial. I'll think a bit more about how to factor the function.

Contributor:
About your points:

> What is `_create_all_options_no_nested_sharding` intended for? It is currently unused, so I might delete it unless you have a use case in mind.

I think we can remove it. It was the first implementation I did, because I was learning DTensor on the go and was using `DTensorSpec.from_dim_map` to generate the specs. But then I realized that it couldn't generate nested shardings like S(0)S(0), so I wrote the second function. I didn't know back then whether this would be useful, so I kept it just in case, but I think we can remove it now.

> `_create_all_options` is pretty close to `factory_rule`. I can probably make `factory_rule` delegate to `_create_all_options`.

Yes, I think it would be good to unify both.

> `_create_all_options` is a bit vague, and the TODO suggests adding Partial, but factories would not want Partial. I'll think a bit more about how to factor the function.

That comment is legacy from `_create_all_options_no_nested_sharding`, which went into `_create_all_options` -- I don't think we want Partial support anymore, so it's basically the same thing as `factory_rule`.

"""
This is an auto-parallel-specific util that won't be upstreamed because of a UX mismatch.

In regular DTensor programs, a user has to either call `torch.full` to get a regular tensor, or
`torch.distributed.tensor.full` (with placements specified) to get a DTensor.

There is no point registering a strategy in DTensor for factories like 'full' since there is no way they
could be used by DTensor's dispatching logic. (Note: DTensor does provide strategies for similar ops like
'new_full' and 'full_like', the difference being there is an input tensor to trigger dispatch off of and to
use to direct the placement options.)

This util applies to any factory function that takes 'size' as the first argument,
and supports Replication and Shard placements all at zero cost.
"""
assert isinstance(op_schema.args_schema[0], torch.Size)
shape = op_schema.args_schema[0]
x = torch.empty(shape, device="meta")
stride = x.stride()
dtype = torch.get_default_dtype()
if len(op_schema.args_schema) >= 3:
assert isinstance(op_schema.args_schema[2], torch.dtype)
dtype = op_schema.args_schema[2]
assert isinstance(dtype, torch.dtype), dtype

# TODO: ensure the solver knows that it is more expensive to Replicate factory functions than to shard them.
# For now, put Replicate last since this might encourage sharding. (Experimentally it seemed so, but this is
# definitely not a stable guarantee and we should fix it properly.)
single_mesh_dim_strategies = [[Shard(i)] for i in range(len(shape))] + [
[Replicate()]
]

"""
Expand the single_mesh_dim_strategies to full mesh dim strategies.
see docs for `expand_to_full_mesh_op_strategy` in _tensor_ops.py in pytorch
"""
all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim

strategy_combs = list(itertools.product(*all_mesh_dim_strategies))

all_strategies = []
for strategy_comb in strategy_combs:
spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)]
output_specs = spec_list[0]
output_specs.tensor_meta = TensorMeta(shape, stride, dtype)

if not is_tensor_shardable(shape, output_specs):
continue

redistribute_cost = [
# TODO: there shouldn't actually be a row here, since there is no input to the op and the rows correspond
# to the inputs. However, the optimization code is not set up to tolerate input-less ops, so hack around it
# (see "/data/users/whc/autoparallel/autoparallel/optimize_sharding.py", line 226, in walk_over_options)
Comment on lines +415 to +417
Contributor:
Yes, I agree with the comment and the current solution you implemented is what I would recommend doing as well.

[0.0]
* len(strategy_combs)
]

strategy = OpSpec(
output_specs=output_specs,
redistribute_cost=redistribute_cost,
)
all_strategies.append(strategy)
return OpStrategy(all_strategies)


# ======================================
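For intuition on what `factory_rule` enumerates, here is a small sketch of the cross-product expansion over mesh dims, assuming a 2-D mesh and an (8, 16) output shape, and printing just the output placements rather than full `DTensorSpec`s:

```python
import itertools

from torch.distributed.tensor.placement_types import Replicate, Shard

mesh_ndim = 2    # assumed 2-D mesh
shape = (8, 16)  # assumed factory output shape, e.g. torch.zeros(8, 16)

# One candidate placement per mesh dim: shard along any tensor dim, or replicate.
single_mesh_dim_strategies = [[Shard(i)] for i in range(len(shape))] + [[Replicate()]]

for comb in itertools.product(*[single_mesh_dim_strategies] * mesh_ndim):
    # zip(*comb) mirrors factory_rule: the first (and only) entry is the output
    # placement tuple, one placement per mesh dim.
    # 3 ** 2 = 9 candidates before shardability filtering.
    (placements,) = tuple(zip(*comb))
    print(placements)
```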
11 changes: 10 additions & 1 deletion autoparallel/utils.py
@@ -10,7 +10,12 @@
from torch.distributed.tensor._ops.utils import generate_redistribute_costs
from torch.utils._pytree import tree_flatten, tree_map_only

from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs
from .propagation_rules import (
TENSOR_FACTORY_OPS,
_op_partial_rules,
_op_rules,
remove_invalid_configs,
)


def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
Contributor:
nit: as follow-up work, it might be good to disable those functions to see whether we are still missing some cases in DTensor sharding propagation

Contributor Author:
Disable which, `propagate_tensor_meta`? I can try that in another PR.

@@ -44,6 +49,10 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
else:
assert tm is None
if strat.input_specs is None:
if op in TENSOR_FACTORY_OPS:
# there isn't an input spec because the op has no input!
continue

supported_ops = {
torch.ops.prims.convert_element_type.default,
torch.ops.aten.clone.default,