Commit 304ebc6

Add factory_strategy to support empty, full, ...
Moves from ad-hoc, incomplete support for a couple of these ops to sharding-aware support for all of the standard factory ops. Still needs further work to add memory constraints that encourage sharding.
1 parent 8fbdba7 commit 304ebc6

3 files changed: +120 −34 lines

autoparallel/apply_sharding.py

Lines changed: 3 additions & 1 deletion
@@ -15,6 +15,8 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.utils._pytree import tree_flatten, tree_map_only

+from .propagation_rules import TENSOR_FACTORY_OPS
+

 def my_redistribute_local_tensor(arg, curr_spec, tgt_spec):
     # if curr_spec.placements == (Shard(0), Shard(0)) and tgt_spec.placements == (
@@ -129,7 +131,7 @@ def call_function(self, target, args, kwargs):
         new_args = self.redistribute_args(args)

         # apply sharding to constructor functions as well
-        if target == torch.ops.aten.full.default:
+        if target in TENSOR_FACTORY_OPS:
             val = list(new_args[0])
             spec = self.sharding_placement[node].output_specs
             for mesh_size, placement in zip(spec.mesh.shape, spec.placements):
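
The hunk above is cut off before the loop body. Going by the note next to `TENSOR_FACTORY_OPS` ("convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding"), the conversion presumably looks like the standalone sketch below; the helper name `local_size_for` and the even-divisibility assert are illustrative assumptions, not code from the commit:

from torch.distributed.tensor.placement_types import Replicate, Shard

def local_size_for(global_size, mesh_shape, placements):
    # Hypothetical helper: divide each sharded dim of the requested global size
    # by the corresponding mesh dim to get the per-rank (local) size.
    val = list(global_size)
    for mesh_size, placement in zip(mesh_shape, placements):
        if placement.is_shard():
            assert val[placement.dim] % mesh_size == 0, "sketch assumes evenly divisible sizes"
            val[placement.dim] //= mesh_size
    return tuple(val)

# e.g. an aten.zeros((8, 16)) placed as (Shard(0), Replicate()) on a 2x4 mesh
print(local_size_for((8, 16), (2, 4), (Shard(0), Replicate())))  # -> (4, 16)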

autoparallel/propagation_rules.py

Lines changed: 105 additions & 32 deletions
@@ -33,7 +33,10 @@
     propagate_shape_and_sharding,
     register_op_strategy_map,
 )
-from torch.distributed.tensor._ops.utils import generate_redistribute_costs
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+)
 from torch.distributed.tensor.placement_types import Replicate, Shard

 # TODO: move this to PyTorch
@@ -60,11 +63,15 @@ def wrapper(impl):
 _op_partial_rules = {}


-def register_opschema_rule(op):
+def register_opschema_rule(ops):
     global _op_partial_rules

     def wrapper(impl):
-        _op_partial_rules[op] = impl
+        if isinstance(ops, list):
+            for op in ops:
+                _op_partial_rules[op] = impl
+        else:
+            _op_partial_rules[ops] = impl
         return impl

     return wrapper
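
With this change, `register_opschema_rule` accepts either a single op overload or a list and registers the same rule for every entry. A minimal usage sketch, assuming the module's existing imports and decorator (the rule name and body here are placeholders, not part of the commit):

# Hypothetical registration: both overloads map to the same rule in _op_partial_rules.
@register_opschema_rule([torch.ops.aten.zeros.default, torch.ops.aten.ones.default])
def toy_factory_rule(mesh, op_schema):
    ...  # placeholder body; the commit's real rule is factory_rule below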
@@ -97,7 +104,12 @@ def remove_invalid_configs(out_strat, mesh):
         output_specs = strategy.output_specs
         if isinstance(output_specs, DTensorSpec):
             output_specs = [output_specs]
-        specs = list(strategy.input_specs) + list(output_specs)
+        if strategy.input_specs is not None:
+            specs = list(strategy.input_specs) + list(output_specs)
+        else:
+            # special case for ops like full, empty, which have no inputs. See further comments by `TENSOR_FACTORY_OPS`
+            specs = list(output_specs)
+
         for spec in specs:
             if spec is None:
                 continue
@@ -335,22 +347,82 @@ def randperm_rule(mesh, specs):
     return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])


-@register_rule(torch.ops.aten.full.default)
-def full_rule(mesh, specs):
-    raise NotImplementedError("Needs hardening, only tested on a few cases")
-    shape = specs[0]
-    # TODO: get the dtype
-    tensor_meta = _gen_tensor_meta(shape)
-    # TODO: I'm hard-coding this here, I'll probably need to do something else about this
-    placement = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
-    # placement = (Replicate(),) * mesh.ndim
-    input_placement = (Replicate(),) * mesh.ndim
-    spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
-    input_spec = DTensorSpec(mesh, input_placement, tensor_meta=tensor_meta)
-    # return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])
-    return OpStrategy(
-        [OpSpec(spec, input_specs=[input_spec], redistribute_cost=[[0.0]])]
-    )
+# We do a few special things for factory ops:
+# - use the factory rule below
+# - fake that they have input schemas so the solver doesn't freak out
+# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
+TENSOR_FACTORY_OPS = [
+    torch.ops.aten.zeros.default,
+    torch.ops.aten.ones.default,
+    torch.ops.aten.full.default,
+    torch.ops.aten.empty.memory_format,
+    torch.ops.aten.rand.default,
+    torch.ops.aten.randn.default,
+]
+
+
+@register_opschema_rule(TENSOR_FACTORY_OPS)
+def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
+    """
+    This is an auto-parallel-specific util that won't be upstreamed because of a UX mismatch.
+
+    In regular DTensor programs, a user has to either call `torch.full` to get a regular tensor, or
+    `torch.distributed.tensor.full` (with placements specified) to get a DTensor.
+
+    There is no point registering a strategy in DTensor for factories like 'full' since there is no way they
+    could be used by DTensor's dispatching logic. (Note: DTensor does provide strategies for similar ops like
+    'new_full' and 'full_like', the difference being there is an input tensor to trigger dispatch off of and to
+    use to direct the placement options.)
+
+    This util applies to any factory function that takes 'size' as the first argument,
+    and supports Replicate and Shard placements, all at zero cost.
+    """
+    shape = op_schema.args_schema[0]
+    x = torch.empty(shape, device="meta")
+    stride = x.stride()
+    dtype = torch.get_default_dtype()
+    if len(op_schema.args_schema) >= 3:
+        # TODO: didn't really verify this
+        dtype = op_schema.args_schema[2]
+        assert isinstance(dtype, torch.dtype), dtype
+
+    # TODO: ensure the solver knows that it is more expensive to Replicate factory functions than to shard them;
+    # for now, put Replicate last since this might encourage sharding
+    single_mesh_dim_strategies = [[Shard(i)] for i in range(len(shape))] + [
+        [Replicate()]
+    ]
+
+    """
+    Expand the single_mesh_dim_strategies to full mesh dim strategies;
+    see docs for `expand_to_full_mesh_op_strategy` in _tensor_ops.py in PyTorch.
+    """
+    all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim
+
+    strategy_combs = list(itertools.product(*all_mesh_dim_strategies))
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)]
+        output_specs = spec_list[0]
+        output_specs.tensor_meta = TensorMeta(shape, stride, dtype)

+        if not is_tensor_shardable(shape, output_specs):
+            continue
+
+        redistribute_cost = [
+            # TODO: there shouldn't actually be a row here, since there is no input to the op and the rows correspond
+            # to the inputs. However, the optimization code is not set up to tolerate input-less ops, so hack around it
+            # (see "/data/users/whc/autoparallel/autoparallel/optimize_sharding.py", line 226, in walk_over_options)
+            [0.0]
+            * len(strategy_combs)
+        ]
+
+        strategy = OpSpec(
+            output_specs=output_specs,
+            redistribute_cost=redistribute_cost,
+        )
+        all_strategies.append(strategy)
+    return OpStrategy(all_strategies)


 # ======================================
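
To see what the enumeration above produces, consider a rank-2 factory call on a 2-D mesh: each mesh dimension independently picks Shard(0), Shard(1), or Replicate. A small self-contained sketch of just that combinatorial step (the shape and mesh rank are example values; the real rule additionally attaches a TensorMeta, filters with `is_tensor_shardable`, and wraps each choice in an OpSpec):

import itertools

from torch.distributed.tensor.placement_types import Replicate, Shard

shape = (8, 16)   # example global size, e.g. from aten.zeros.default
mesh_ndim = 2     # example 2-D device mesh

# Per mesh dim: shard along any tensor dim, or replicate (listed last, as in the rule above).
single_mesh_dim_strategies = [[Shard(i)] for i in range(len(shape))] + [[Replicate()]]
all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh_ndim

for comb in itertools.product(*all_mesh_dim_strategies):
    placements = tuple(spec[0] for spec in comb)  # one placement per mesh dim
    print(placements)
# prints 9 combinations, from (Shard(dim=0), Shard(dim=0)) to (Replicate(), Replicate())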
@@ -603,18 +675,19 @@ def index_put_rule(mesh, op_schema):
         t_strats = [DTensorSpec(mesh, placements=ispec.placements)]
         s = OpSpec(output_specs=ospec, input_specs=[ispec] + idxs_strats + t_strats)

-        redistribute_costs = (
-            [generate_redistribute_costs(specs[0], ospec)]
-            + [
-                generate_redistribute_costs(kk, idxs_strat)
-                for kk, idxs_strat in zip(kspc, idxs_strats)
-            ]
-            + [generate_redistribute_costs(specs[2], t_strats[0])]
-        )
-        s.redistribute_cost = redistribute_costs
-        res.append(s)
-    out_strat = OpStrategy(res)
-    return out_strat
+
+        # redistribute_costs = (
+        #     [generate_redistribute_costs(specs[0], ospec)]
+        #     + [
+        #         generate_redistribute_costs(kk, idxs_strat)
+        #         for kk, idxs_strat in zip(kspc, idxs_strats)
+        #     ]
+        #     + [generate_redistribute_costs(specs[2], t_strats[0])]
+        # )
+        # s.redistribute_cost = redistribute_costs
+        # res.append(s)
+        # out_strat = OpStrategy(res)
+        # return out_strat


 @register_opschema_rule(torch.ops.aten._scaled_dot_product_efficient_attention.default)

autoparallel/utils.py

Lines changed: 12 additions & 1 deletion
@@ -3,14 +3,16 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+from re import I
+
 import torch
 from torch.distributed._tensor.placement_types import TensorMeta
 from torch.distributed.device_mesh import _get_device_handle
 from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, TupleStrategy
 from torch.distributed.tensor._ops.utils import generate_redistribute_costs
 from torch.utils._pytree import tree_flatten, tree_map_only

-from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs
+from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs, TENSOR_FACTORY_OPS


 def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
@@ -44,6 +46,15 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
         else:
             assert tm is None
         if strat.input_specs is None:
+            if op in TENSOR_FACTORY_OPS:
+                # there isn't an input spec bc the op has no input!
+                # continue
+
+                # but index_put op insists on looking at 'input_specs' of its input, which seems absurd.
+                # so just copy it for now and fix later
+                strat.input_specs = (strat.output_specs,)
+                continue
+
             supported_ops = {
                 torch.ops.prims.convert_element_type.default,
                 torch.ops.aten.clone.default,
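
For illustration only, the net effect of the new branch is to reuse the output spec as a single stand-in input spec so that downstream code indexing `input_specs` keeps working. A toy sketch with a made-up strategy class (not the repo's OpSpec/OpStrategy types):

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class ToyStrategy:
    output_specs: str
    input_specs: Optional[Tuple[str, ...]] = None

strat = ToyStrategy(output_specs="Shard(0) on a 1-D mesh")
# Factory ops have no inputs, so input_specs is None; copy the output spec in
# as a fake input so code that indexes input_specs does not fall over.
if strat.input_specs is None:
    strat.input_specs = (strat.output_specs,)

print(strat.input_specs)  # ('Shard(0) on a 1-D mesh',)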
