Add factory_strategy to support empty, full, ... #44
autoparallel/propagation_rules.py
@@ -33,7 +33,10 @@
     propagate_shape_and_sharding,
     register_op_strategy_map,
 )
-from torch.distributed.tensor._ops.utils import generate_redistribute_costs
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+)
 from torch.distributed.tensor.placement_types import Replicate, Shard
 
 # TODO: move this to PyTorch
@@ -60,11 +63,15 @@ def wrapper(impl):
 _op_partial_rules = {}
 
 
-def register_opschema_rule(op):
+def register_opschema_rule(ops):
     global _op_partial_rules
 
     def wrapper(impl):
-        _op_partial_rules[op] = impl
+        if isinstance(ops, list):
+            for op in ops:
+                _op_partial_rules[op] = impl
+        else:
+            _op_partial_rules[ops] = impl
         return impl
 
     return wrapper
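With this change the decorator accepts either a single op overload or a list of overloads. A minimal usage sketch (the chosen ops and rule bodies are illustrative only, not taken from the PR, and it assumes `register_opschema_rule` from `autoparallel/propagation_rules.py` is in scope):

```python
import torch

# assumption: from autoparallel.propagation_rules import register_opschema_rule

@register_opschema_rule(torch.ops.aten.full.default)
def single_op_rule(mesh, op_schema):
    ...  # placeholder body; a real rule returns an OpStrategy built from op_schema

# the same implementation is recorded in _op_partial_rules under every op in the list
@register_opschema_rule([torch.ops.aten.zeros.default, torch.ops.aten.ones.default])
def list_of_ops_rule(mesh, op_schema):
    ...
```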
|
|
@@ -97,7 +104,12 @@ def remove_invalid_configs(out_strat, mesh):
         output_specs = strategy.output_specs
         if isinstance(output_specs, DTensorSpec):
             output_specs = [output_specs]
-        specs = list(strategy.input_specs) + list(output_specs)
+        if strategy.input_specs is not None:
+            specs = list(strategy.input_specs) + list(output_specs)
+        else:
+            # special case for ops like full, empty, which have no inputs. See further comments by `TENSOR_FACTORY_OPS`
+            specs = list(output_specs)
+
         for spec in specs:
             if spec is None:
                 continue
@@ -335,22 +347,84 @@ def randperm_rule(mesh, specs):
     return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])
 
 
-@register_rule(torch.ops.aten.full.default)
-def full_rule(mesh, specs):
-    raise NotImplementedError("Needs hardening, only tested on a few cases")
-    shape = specs[0]
-    # TODO: get the dtype
-    tensor_meta = _gen_tensor_meta(shape)
-    # TODO: I'm hard-coding this here, I'll probably need to do something else about this
-    placement = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
-    # placement = (Replicate(),) * mesh.ndim
-    input_placement = (Replicate(),) * mesh.ndim
-    spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
-    input_spec = DTensorSpec(mesh, input_placement, tensor_meta=tensor_meta)
-    # return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])
-    return OpStrategy(
-        [OpSpec(spec, input_specs=[input_spec], redistribute_cost=[[0.0]])]
-    )
+# We do a few special things for factory ops
+# - use the factory rule below
+# - fake that they have input schemas so the solver doesn't freak out
+# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
+TENSOR_FACTORY_OPS = [
+    torch.ops.aten.zeros.default,
+    torch.ops.aten.ones.default,
+    torch.ops.aten.full.default,
+    torch.ops.aten.empty.memory_format,
+    torch.ops.aten.rand.default,
+    torch.ops.aten.randn.default,
+]
+
+
+@register_opschema_rule(TENSOR_FACTORY_OPS)
+def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
+    """
+    This is an auto-parallel specific util that won't be upstreamed because of a UX mismatch.
+
+    In regular DTensor programs, a user has to either call `torch.full` to get a regular tensor, or
+    `torch.distributed.tensor.full` (with placements specified) to get a DTensor.
+
+    There is no point registering a strategy in DTensor for factories like 'full' since there is no way they
+    could be used by DTensor's dispatching logic. (Note: DTensor does provide strategies for similar ops like
+    'new_full' and 'full_like', the difference being there is an input tensor to trigger dispatch off of and to
+    use to direct the placement options.)
+
+    This util applies to any factory function that takes 'size' as the first argument,
+    and supports Replication and Shard placements, all at zero cost.
+    """
+    assert isinstance(op_schema.args_schema[0], torch.Size)
+    shape = op_schema.args_schema[0]
+    x = torch.empty(shape, device="meta")
+    stride = x.stride()
+    dtype = torch.get_default_dtype()
+    if len(op_schema.args_schema) >= 3:
+        assert isinstance(op_schema.args_schema[2], torch.dtype)
+        dtype = op_schema.args_schema[2]
+    assert isinstance(dtype, torch.dtype), dtype
+
+    # TODO: ensure the solver knows that it is more expensive to Replicate factory functions than shard.
+    # For now, put replicate last since this might encourage sharding. (Experimentally it seemed so, but definitely
+    # this is not a stable guarantee and we should fix this properly.)
+    single_mesh_dim_strategies = [[Shard(i)] for i in range(len(shape))] + [
+        [Replicate()]
+    ]
+
+    """
+    Expand the single_mesh_dim_strategies to full mesh dim strategies.
+    see docs for `expand_to_full_mesh_op_strategy` in _tensor_ops.py in pytorch
+    """
+    all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim
+
+    strategy_combs = list(itertools.product(*all_mesh_dim_strategies))
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)]
+        output_specs = spec_list[0]
+        output_specs.tensor_meta = TensorMeta(shape, stride, dtype)
+
+        if not is_tensor_shardable(shape, output_specs):
+            continue
+
+        redistribute_cost = [
+            # TODO: there shouldn't actually be a row here, since there is no input to the op and the rows correspond
+            # to the inputs. However, the optimization code is not set up to tolerate input-less ops, so hack around it
+            # (see "/data/users/whc/autoparallel/autoparallel/optimize_sharding.py", line 226, in walk_over_options)
Comment on lines +415 to +417 (Contributor): Yes, I agree with the comment and the current solution you implemented is what I would recommend doing as well.
+            [0.0]
+            * len(strategy_combs)
+        ]
+
+        strategy = OpSpec(
+            output_specs=output_specs,
+            redistribute_cost=redistribute_cost,
+        )
+        all_strategies.append(strategy)
+    return OpStrategy(all_strategies)
 
 
 # ======================================
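To make the expansion in `factory_rule` concrete: for a rank-2 tensor on a 2-D mesh it considers, per mesh dim, either sharding along some tensor dim or replicating. A small standalone sketch (not part of the PR; the rank-2 tensor and 2-D mesh are assumed, and it skips the `DTensorSpec` construction and `is_tensor_shardable` filtering the real code does):

```python
import itertools

from torch.distributed.tensor.placement_types import Replicate, Shard

tensor_ndim = 2  # rank of the factory output, e.g. torch.full((8, 16), 1.0)
mesh_ndim = 2    # assumed 2-D device mesh

# Per-mesh-dim options: shard along any tensor dim, or replicate (listed last,
# mirroring the ordering comment in factory_rule).
single_mesh_dim_strategies = [[Shard(i)] for i in range(tensor_ndim)] + [[Replicate()]]
all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh_ndim

for comb in itertools.product(*all_mesh_dim_strategies):
    # zip(*comb) regroups the per-mesh-dim choices into a single placements tuple
    (placements,) = zip(*comb)
    print(placements)
# 3 * 3 = 9 combinations, including nested sharding such as (Shard(dim=0), Shard(dim=0))
```

In `factory_rule` each surviving combination additionally gets a `TensorMeta` attached and a single fake zero-cost redistribute row, so the solver can treat the input-less op like any other.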
Second file in the diff:
@@ -10,7 +10,12 @@
 from torch.distributed.tensor._ops.utils import generate_redistribute_costs
 from torch.utils._pytree import tree_flatten, tree_map_only
 
-from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs
+from .propagation_rules import (
+    TENSOR_FACTORY_OPS,
+    _op_partial_rules,
+    _op_rules,
+    remove_invalid_configs,
+)
 
 
 def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
Contributor: nit: as a follow-up work, it might be good to disable those functions to see if we are still missing some cases in DTensor sharding propagation

Contributor (Author): disable which, 'propagate_tensor_meta'? I can try that in another PR.
@@ -44,6 +49,10 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
         else:
             assert tm is None
         if strat.input_specs is None:
+            if op in TENSOR_FACTORY_OPS:
+                # there isn't an input spec bc the op has no input!
+                continue
+
             supported_ops = {
                 torch.ops.prims.convert_element_type.default,
                 torch.ops.aten.clone.default,
Contributor: Maybe we could replace / unify this implementation with the implementation in `_create_all_options` (https://github.com/pytorch-labs/autoparallel/blob/b53ad103b2054177db1c0ac50d0b0021a5b8bb57/autoparallel/propagation_rules.py#L119-L159)? Can be left for the future, just bringing this up as I think they are quite similar.
Contributor (Author): I wasn't aware of these; agree we can refactor these to remove the duplication. I can do that in a follow-up PR. A few questions:
Contributor: About your points:

1. I think we can remove it. It was the first implementation that I did because I was learning DTensor on the go and was using `DTensorSpec.from_dim_map` to generate the specs. But then I realized that it couldn't generate nested shardings like `S(0)S(0)`, so I wrote the second function. I didn't know back then if this would be useful or not so I kept it just in case, but I think we can remove it now.
2. Yes, I think it would be a good thing to unify both.
3. That comment is legacy from `_create_all_options_no_nested_sharding`, which went into `_create_all_options` -- I don't think we want `Partial` support anymore, so it's basically the same thing as `factory_rule`.
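To illustrate the nested-sharding point above, a sketch based on my reading of `DTensorSpec.from_dim_map`, where a dim_map holds one mesh-dim index (or -1 for replicated) per tensor dim; treat that semantics as an assumption rather than something stated in this thread:

```python
from torch.distributed.tensor.placement_types import Replicate, Shard

# A dim_map is indexed by tensor dim and stores the single mesh dim that tensor
# dim is sharded over (-1 = replicated), so each tensor dim can appear on at
# most one mesh dim:
dim_map = [0, -1]              # expressible: placements (Shard(0), Replicate())

# Placements are indexed by mesh dim, so the same tensor dim may repeat. This is
# the nested sharding ("S(0)S(0)") that a dim_map cannot encode but the
# product-based expansion in factory_rule can emit directly:
nested = (Shard(0), Shard(0))  # tensor dim 0 split across both mesh dims
print(dim_map, nested)
```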