
Commit 6870a7f

Add factory_strategy to support empty, full, ...
Replaces ad-hoc, incomplete support for a couple of these ops with sharding-aware support for all of the standard factory ops. Still needs further work to add memory constraints to encourage sharding.
1 parent 8fbdba7 commit 6870a7f

File tree

3 files changed: +119 -34 lines

  autoparallel/apply_sharding.py
  autoparallel/propagation_rules.py
  autoparallel/utils.py

autoparallel/apply_sharding.py

Lines changed: 3 additions & 1 deletion
@@ -15,6 +15,8 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.utils._pytree import tree_flatten, tree_map_only
 
+from .propagation_rules import TENSOR_FACTORY_OPS
+
 
 def my_redistribute_local_tensor(arg, curr_spec, tgt_spec):
     # if curr_spec.placements == (Shard(0), Shard(0)) and tgt_spec.placements == (
@@ -129,7 +131,7 @@ def call_function(self, target, args, kwargs):
         new_args = self.redistribute_args(args)
 
         # apply sharding to constructor functions as well
-        if target == torch.ops.aten.full.default:
+        if target in TENSOR_FACTORY_OPS:
             val = list(new_args[0])
             spec = self.sharding_placement[node].output_specs
             for mesh_size, placement in zip(spec.mesh.shape, spec.placements):
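The constructor branch above now covers every op in TENSOR_FACTORY_OPS, and the code following it rewrites the global size argument into a per-rank local size by dividing each sharded dim by the corresponding mesh size. As a rough standalone sketch of just that arithmetic (plain tuples stand in for the real mesh and placement objects, and local_factory_size is a hypothetical helper, not something in this repo):

# Hypothetical, simplified illustration of the local-size rewrite applied to
# factory ops; placements are modeled as ("shard", dim) / ("replicate",)
# tuples instead of DTensor placement types.
def local_factory_size(global_size, mesh_shape, placements):
    local = list(global_size)
    for mesh_size, placement in zip(mesh_shape, placements):
        if placement[0] == "shard":
            # assume the dim divides evenly for this sketch
            local[placement[1]] //= mesh_size
    return local

# a [16, 8] tensor on a 2x4 mesh, sharded on tensor dim 0 along the first mesh axis
print(local_factory_size([16, 8], (2, 4), [("shard", 0), ("replicate",)]))  # -> [8, 8]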

autoparallel/propagation_rules.py

Lines changed: 104 additions & 32 deletions
@@ -33,7 +33,10 @@
     propagate_shape_and_sharding,
     register_op_strategy_map,
 )
-from torch.distributed.tensor._ops.utils import generate_redistribute_costs
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+)
 from torch.distributed.tensor.placement_types import Replicate, Shard
 
 # TODO: move this to PyTorch
@@ -60,11 +63,15 @@ def wrapper(impl):
 _op_partial_rules = {}
 
 
-def register_opschema_rule(op):
+def register_opschema_rule(ops):
     global _op_partial_rules
 
     def wrapper(impl):
-        _op_partial_rules[op] = impl
+        if isinstance(ops, list):
+            for op in ops:
+                _op_partial_rules[op] = impl
+        else:
+            _op_partial_rules[ops] = impl
         return impl
 
     return wrapper
@@ -97,7 +104,11 @@ def remove_invalid_configs(out_strat, mesh):
         output_specs = strategy.output_specs
         if isinstance(output_specs, DTensorSpec):
             output_specs = [output_specs]
-        specs = list(strategy.input_specs) + list(output_specs)
+        if strategy.input_specs is not None:
+            specs = list(strategy.input_specs) + list(output_specs)
+        else:
+            specs = list(output_specs)
+
         for spec in specs:
             if spec is None:
                 continue
@@ -335,22 +346,82 @@ def randperm_rule(mesh, specs):
     return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])
 
 
-@register_rule(torch.ops.aten.full.default)
-def full_rule(mesh, specs):
-    raise NotImplementedError("Needs hardening, only tested on a few cases")
-    shape = specs[0]
-    # TODO: get the dtype
-    tensor_meta = _gen_tensor_meta(shape)
-    # TODO: I'm hard-coding this here, I'll probably need to do something else about this
-    placement = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
-    # placement = (Replicate(),) * mesh.ndim
-    input_placement = (Replicate(),) * mesh.ndim
-    spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
-    input_spec = DTensorSpec(mesh, input_placement, tensor_meta=tensor_meta)
-    # return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])
-    return OpStrategy(
-        [OpSpec(spec, input_specs=[input_spec], redistribute_cost=[[0.0]])]
-    )
+# We do a few special things for factory ops
+# - use the factory rule below
+# - fake that they have input schemas so the solver doesn't freak out
+# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
+TENSOR_FACTORY_OPS = [
+    torch.ops.aten.zeros.default,
+    torch.ops.aten.ones.default,
+    torch.ops.aten.full.default,
+    torch.ops.aten.empty.memory_format,
+    torch.ops.aten.rand.default,
+    torch.ops.aten.randn.default,
+]
+
+
+@register_opschema_rule(TENSOR_FACTORY_OPS)
+def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
+    """
+    This is an auto-parallel specific util that won't be upstreamed because of a UX mismatch.
+
+    In regular DTensor programs, a user has to either call `torch.full` to get a regular tensor, or
+    `torch.distributed.tensor.full` (with placements specified) to get a DTensor.
+
+    There is no point registering a strategy in DTensor for factories like 'full' since there is no way they
+    could be used by DTensor's dispatching logic. (Note: DTensor does provide strategies for similar ops like
+    'new_full' and 'full_like', the difference being there is an input tensor to trigger dispatch off of and to
+    use to direct the placement options.)
+
+    This util applies to any factory function that takes 'size' as the first argument,
+    and supports Replication and Shard placements all at zero cost.
+    """
+    shape = op_schema.args_schema[0]
+    x = torch.empty(shape, device="meta")
+    stride = x.stride()
+    dtype = torch.get_default_dtype()
+    if len(op_schema.args_schema) >= 3:
+        # TODO: didn't really verify this
+        dtype = op_schema.args_schema[2]
+        assert isinstance(dtype, torch.dtype), dtype
+
+    # TODO: ensure the solver knows that it is more expensive to Replicate factory functions than shard
+    # for now, put replicate last since this might encourage sharding :?
+    single_mesh_dim_strategies = [[Shard(i)] for i in range(len(shape))] + [
+        [Replicate()]
+    ]
+
+    """
+    Expand the single_mesh_dim_strategies to full mesh dim strategies.
+    see docs for `expand_to_full_mesh_op_strategy` in _tensor_ops.py in pytorch
+    """
+    all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim
+
+    strategy_combs = list(itertools.product(*all_mesh_dim_strategies))
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)]
+        output_specs = spec_list[0]
+        output_specs.tensor_meta = TensorMeta(shape, stride, dtype)
+
+        if not is_tensor_shardable(shape, output_specs):
+            continue
+
+        redistribute_cost = [
+            # TODO: there shouldn't actually be a row here, since there is no input to the op and the rows correspond
+            # to the inputs. However, the optimization code is not set up to tolerate input-less ops, so hack around it
+            # (see "/data/users/whc/autoparallel/autoparallel/optimize_sharding.py", line 226, in walk_over_options)
+            [0.0]
+            * len(strategy_combs)
+        ]
+
+        strategy = OpSpec(
+            output_specs=output_specs,
+            redistribute_cost=redistribute_cost,
+        )
+        all_strategies.append(strategy)
+    return OpStrategy(all_strategies)
 
 
 # ======================================
@@ -603,18 +674,19 @@ def index_put_rule(mesh, op_schema):
         t_strats = [DTensorSpec(mesh, placements=ispec.placements)]
         s = OpSpec(output_specs=ospec, input_specs=[ispec] + idxs_strats + t_strats)
 
-        redistribute_costs = (
-            [generate_redistribute_costs(specs[0], ospec)]
-            + [
-                generate_redistribute_costs(kk, idxs_strat)
-                for kk, idxs_strat in zip(kspc, idxs_strats)
-            ]
-            + [generate_redistribute_costs(specs[2], t_strats[0])]
-        )
-        s.redistribute_cost = redistribute_costs
-        res.append(s)
-    out_strat = OpStrategy(res)
-    return out_strat
+
+    # redistribute_costs = (
+    #     [generate_redistribute_costs(specs[0], ospec)]
+    #     + [
+    #         generate_redistribute_costs(kk, idxs_strat)
+    #         for kk, idxs_strat in zip(kspc, idxs_strats)
+    #     ]
+    #     + [generate_redistribute_costs(specs[2], t_strats[0])]
+    # )
+    # s.redistribute_cost = redistribute_costs
+    # res.append(s)
+    # out_strat = OpStrategy(res)
+    # return out_strat
 
 
 @register_opschema_rule(torch.ops.aten._scaled_dot_product_efficient_attention.default)
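
For intuition about what factory_rule enumerates: each mesh dimension independently picks either Shard(i) for some tensor dim i or Replicate, the per-mesh-dim options are expanded with itertools.product, and combinations rejected by is_tensor_shardable are dropped. A minimal sketch of that expansion, assuming a rank-2 shape on a 2-D mesh and using plain strings in place of the real placement and DTensorSpec objects:

import itertools

# Stand-in strings for the real placements; factory_rule builds Shard(i) /
# Replicate() objects and wraps each combination in a DTensorSpec.
shape = (16, 8)   # the 'size' argument of the factory op
mesh_ndim = 2     # e.g. a 2-D device mesh

# per-mesh-dim options: shard along any tensor dim, or replicate (listed last)
single_mesh_dim_strategies = [f"Shard({i})" for i in range(len(shape))] + ["Replicate"]

# one placement per mesh dim, every combination
strategy_combs = list(itertools.product(single_mesh_dim_strategies, repeat=mesh_ndim))
for comb in strategy_combs:
    print(comb)
# 9 combinations here, from ('Shard(0)', 'Shard(0)') to ('Replicate', 'Replicate');
# the real rule additionally skips combinations that is_tensor_shardable rejects.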

autoparallel/utils.py

Lines changed: 12 additions & 1 deletion
@@ -3,14 +3,16 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from re import I
+
 import torch
 from torch.distributed._tensor.placement_types import TensorMeta
 from torch.distributed.device_mesh import _get_device_handle
 from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, TupleStrategy
 from torch.distributed.tensor._ops.utils import generate_redistribute_costs
 from torch.utils._pytree import tree_flatten, tree_map_only
 
-from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs
+from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs, TENSOR_FACTORY_OPS
 
 
 def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
@@ -44,6 +46,15 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
         else:
             assert tm is None
         if strat.input_specs is None:
+            if op in TENSOR_FACTORY_OPS:
+                # there isn't an input spec bc the op has no input!
+                # continue
+
+                # but index_put op insists on looking at 'input_specs' of its input, which seems absurd.
+                # so just copy it for now and fix later
+                strat.input_specs = (strat.output_specs,)
+                continue
+
             supported_ops = {
                 torch.ops.prims.convert_element_type.default,
                 torch.ops.aten.clone.default,
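
The new branch in propagate_tensor_meta mirrors the output spec into input_specs purely so that downstream code indexing input_specs does not trip over an op with no inputs. A minimal sketch of that pattern under assumed, simplified types (FakeOpSpec is a hypothetical stand-in for DTensor's OpSpec):

from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical stand-in for DTensor's OpSpec, just to illustrate the fix-up.
@dataclass
class FakeOpSpec:
    output_specs: str
    input_specs: Optional[Tuple[str, ...]] = None

def fake_input_specs_for_factory(strat: FakeOpSpec) -> FakeOpSpec:
    # factory ops legitimately have no inputs, so input_specs is None;
    # mirror the output spec so code that indexes input_specs keeps working
    if strat.input_specs is None:
        strat.input_specs = (strat.output_specs,)
    return strat

strat = fake_input_specs_for_factory(FakeOpSpec(output_specs="Shard(0)"))
print(strat.input_specs)  # ('Shard(0)',)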
