Commit 5b3cacf

Add factory_strategy to support empty, full, ...
Replaces ad-hoc, incomplete support for a couple of these ops with support for all of the standard factory ops, including sharding. Further work is still needed to add memory constraints that encourage sharding.
1 parent: b53ad10

File tree

3 files changed: +107 additions, -22 deletions

autoparallel/apply_sharding.py

Lines changed: 3 additions & 1 deletion
@@ -15,6 +15,8 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.utils._pytree import tree_flatten, tree_map_only
 
+from .propagation_rules import TENSOR_FACTORY_OPS
+
 
 def my_redistribute_local_tensor(arg, curr_spec, tgt_spec):
     # if curr_spec.placements == (Shard(0), Shard(0)) and tgt_spec.placements == (
@@ -129,7 +131,7 @@ def call_function(self, target, args, kwargs):
         new_args = self.redistribute_args(args)
 
         # apply sharding to constructor functions as well
-        if target == torch.ops.aten.full.default:
+        if target in TENSOR_FACTORY_OPS:
             val = list(new_args[0])
             spec = self.sharding_placement[node].output_specs
             for mesh_size, placement in zip(spec.mesh.shape, spec.placements):
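
For factory ops, the branch above rewrites the requested global size into a per-rank local size by dividing each sharded dimension by the size of the mesh dimension it is sharded over. A minimal sketch of that conversion, assuming even divisibility; the helper name local_factory_size and its arguments are illustrative, not names from this module:

import torch
from torch.distributed.tensor.placement_types import Replicate, Shard

def local_factory_size(global_size, mesh_shape, placements):
    # Divide each sharded dim by its mesh size to get the per-rank size
    # a factory op should allocate (assumes the dim divides evenly).
    local = list(global_size)
    for mesh_size, placement in zip(mesh_shape, placements):
        if isinstance(placement, Shard):
            assert local[placement.dim] % mesh_size == 0, "uneven sharding not handled here"
            local[placement.dim] //= mesh_size
    return torch.Size(local)

# A global [8, 16] tensor on a 2x4 mesh placed as (Shard(0), Replicate())
# is allocated as a local [4, 16] tensor on every rank.
assert local_factory_size(torch.Size([8, 16]), (2, 4), (Shard(0), Replicate())) == torch.Size([4, 16])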

autoparallel/propagation_rules.py

Lines changed: 94 additions & 20 deletions
@@ -33,7 +33,10 @@
     propagate_shape_and_sharding,
     register_op_strategy_map,
 )
-from torch.distributed.tensor._ops.utils import generate_redistribute_costs
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+)
 from torch.distributed.tensor.placement_types import Replicate, Shard
 
 # TODO: move this to PyTorch
@@ -60,11 +63,15 @@ def wrapper(impl):
 _op_partial_rules = {}
 
 
-def register_opschema_rule(op):
+def register_opschema_rule(ops):
     global _op_partial_rules
 
     def wrapper(impl):
-        _op_partial_rules[op] = impl
+        if isinstance(ops, list):
+            for op in ops:
+                _op_partial_rules[op] = impl
+        else:
+            _op_partial_rules[ops] = impl
         return impl
 
     return wrapper
@@ -97,7 +104,12 @@ def remove_invalid_configs(out_strat, mesh):
         output_specs = strategy.output_specs
         if isinstance(output_specs, DTensorSpec):
             output_specs = [output_specs]
-        specs = list(strategy.input_specs) + list(output_specs)
+        if strategy.input_specs is not None:
+            specs = list(strategy.input_specs) + list(output_specs)
+        else:
+            # special case for ops like full, empty, which have no inputs. See further comments by `TENSOR_FACTORY_OPS`
+            specs = list(output_specs)
+
         for spec in specs:
             if spec is None:
                 continue
@@ -335,22 +347,84 @@ def randperm_rule(mesh, specs):
     return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])
 
 
-@register_rule(torch.ops.aten.full.default)
-def full_rule(mesh, specs):
-    raise NotImplementedError("Needs hardening, only tested on a few cases")
-    shape = specs[0]
-    # TODO: get the dtype
-    tensor_meta = _gen_tensor_meta(shape)
-    # TODO: I'm hard-coding this here, I'll probably need to do something else about this
-    placement = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
-    # placement = (Replicate(),) * mesh.ndim
-    input_placement = (Replicate(),) * mesh.ndim
-    spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
-    input_spec = DTensorSpec(mesh, input_placement, tensor_meta=tensor_meta)
-    # return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])
-    return OpStrategy(
-        [OpSpec(spec, input_specs=[input_spec], redistribute_cost=[[0.0]])]
-    )
+# We do a few special things for factory ops:
+# - use the factory rule below
+# - fake that they have input schemas so the solver doesn't freak out
+# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
+TENSOR_FACTORY_OPS = [
+    torch.ops.aten.zeros.default,
+    torch.ops.aten.ones.default,
+    torch.ops.aten.full.default,
+    torch.ops.aten.empty.memory_format,
+    torch.ops.aten.rand.default,
+    torch.ops.aten.randn.default,
+]
+
+
+@register_opschema_rule(TENSOR_FACTORY_OPS)
+def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
+    """
+    This is an autoparallel-specific util that won't be upstreamed because of a UX mismatch.
+
+    In regular DTensor programs, a user has to either call `torch.full` to get a regular tensor, or
+    `torch.distributed.tensor.full` (with placements specified) to get a DTensor.
+
+    There is no point registering a strategy in DTensor for factories like 'full' since there is no way they
+    could be used by DTensor's dispatching logic. (Note: DTensor does provide strategies for similar ops like
+    'new_full' and 'full_like', the difference being there is an input tensor to trigger dispatch off of and to
+    use to direct the placement options.)
+
+    This util applies to any factory function that takes 'size' as the first argument,
+    and supports Replicate and Shard placements, all at zero cost.
+    """
+    assert isinstance(op_schema.args_schema[0], torch.Size)
+    shape = op_schema.args_schema[0]
+    x = torch.empty(shape, device="meta")
+    stride = x.stride()
+    dtype = torch.get_default_dtype()
+    if len(op_schema.args_schema) >= 3:
+        assert isinstance(op_schema.args_schema[2], torch.dtype)
+        dtype = op_schema.args_schema[2]
+    assert isinstance(dtype, torch.dtype), dtype
+
+    # TODO: ensure the solver knows that it is more expensive to Replicate factory functions than to shard them.
+    # For now, put Replicate last since this might encourage sharding. (Experimentally it seemed so, but this is
+    # definitely not a stable guarantee and we should fix it properly.)
+    single_mesh_dim_strategies = [[Shard(i)] for i in range(len(shape))] + [
+        [Replicate()]
+    ]
+
+    # Expand the single_mesh_dim_strategies to full mesh-dim strategies;
+    # see the docs for `expand_to_full_mesh_op_strategy` in _tensor_ops.py in PyTorch.
+    all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim
+
+    strategy_combs = list(itertools.product(*all_mesh_dim_strategies))
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)]
+        output_specs = spec_list[0]
+        output_specs.tensor_meta = TensorMeta(shape, stride, dtype)
+
+        if not is_tensor_shardable(shape, output_specs):
+            continue
+
+        redistribute_cost = [
+            # TODO: there shouldn't actually be a row here, since there is no input to the op and the rows correspond
+            # to the inputs. However, the optimization code is not set up to tolerate input-less ops, so hack around it
+            # (see walk_over_options in autoparallel/optimize_sharding.py).
+            [0.0]
+            * len(strategy_combs)
+        ]
+
+        strategy = OpSpec(
+            output_specs=output_specs,
+            redistribute_cost=redistribute_cost,
+        )
+        all_strategies.append(strategy)
+    return OpStrategy(all_strategies)
 
 
 # ======================================
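
factory_rule builds its candidate placements by taking the cartesian product of the per-mesh-dim choices. The following is a small, self-contained sketch of that enumeration for a rank-2 output on a 2-D mesh; the variable names are illustrative, and the real code additionally wraps each result in a DTensorSpec and filters it with is_tensor_shardable:

import itertools
from torch.distributed.tensor.placement_types import Replicate, Shard

# Per-mesh-dim choices for a rank-2 output: shard dim 0, shard dim 1,
# or replicate (Replicate listed last, as in factory_rule).
shape_rank, mesh_ndim = 2, 2
per_mesh_dim = [[Shard(i)] for i in range(shape_rank)] + [[Replicate()]]

# itertools.product over the mesh dims yields 3**2 = 9 candidate output
# placements; factory_rule then drops the ones the shape cannot support.
for comb in itertools.product(*([per_mesh_dim] * mesh_ndim)):
    (output_placements,) = list(zip(*comb))  # one spec per tensor; only the output here
    print(output_placements)
# (Shard(dim=0), Shard(dim=0)), (Shard(dim=0), Shard(dim=1)), ...,
# (Replicate(), Replicate())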

autoparallel/utils.py

Lines changed: 10 additions & 1 deletion
@@ -10,7 +10,12 @@
 from torch.distributed.tensor._ops.utils import generate_redistribute_costs
 from torch.utils._pytree import tree_flatten, tree_map_only
 
-from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs
+from .propagation_rules import (
+    TENSOR_FACTORY_OPS,
+    _op_partial_rules,
+    _op_rules,
+    remove_invalid_configs,
+)
 
 
 def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
@@ -44,6 +49,10 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
         else:
             assert tm is None
         if strat.input_specs is None:
+            if op in TENSOR_FACTORY_OPS:
+                # there isn't an input spec because the op has no input!
+                continue
+
             supported_ops = {
                 torch.ops.prims.convert_element_type.default,
                 torch.ops.aten.clone.default,
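
As context for why these factory ops need special handling here rather than in DTensor itself (the UX mismatch that factory_rule's docstring mentions): with plain DTensor, a sharded factory result has to be requested explicitly through the torch.distributed.tensor factories, whereas autoparallel only ever sees the aten-level call and must choose placements itself. A hedged illustration, not code from this repo; it needs an initialized process group (e.g. run under torchrun with 2 ranks):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, full

# Plain factory call: an ordinary local tensor; DTensor dispatch never sees it.
t = torch.full((8, 16), 1.0)

# Explicit DTensor factory: the user states mesh and placements up front.
mesh = init_device_mesh("cpu", (2,))
d = full((8, 16), 1.0, device_mesh=mesh, placements=[Shard(0)])
print(d.to_local().shape)  # torch.Size([4, 16]) on each of the 2 ranks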
