    propagate_shape_and_sharding,
    register_op_strategy_map,
)
-from torch.distributed.tensor._ops.utils import generate_redistribute_costs
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+)
from torch.distributed.tensor.placement_types import Replicate, Shard

# TODO: move this to PyTorch
@@ -60,11 +63,15 @@ def wrapper(impl):
_op_partial_rules = {}


-def register_opschema_rule(op):
+def register_opschema_rule(ops):
    global _op_partial_rules

    def wrapper(impl):
-        _op_partial_rules[op] = impl
+        if isinstance(ops, list):
+            for op in ops:
+                _op_partial_rules[op] = impl
+        else:
+            _op_partial_rules[ops] = impl
        return impl

    return wrapper
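# --- Illustrative sketch, not part of this diff: with the list form above, one
# --- rule can now be registered for several ops at once. `example_rule` is a
# --- hypothetical rule function; the single-op form keeps working unchanged.

@register_opschema_rule([torch.ops.aten.zeros.default, torch.ops.aten.ones.default])
def example_rule(mesh, op_schema):
    ...  # registered once, stored under both ops in _op_partial_rules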
@@ -97,7 +104,12 @@ def remove_invalid_configs(out_strat, mesh):
        output_specs = strategy.output_specs
        if isinstance(output_specs, DTensorSpec):
            output_specs = [output_specs]
-        specs = list(strategy.input_specs) + list(output_specs)
+        if strategy.input_specs is not None:
+            specs = list(strategy.input_specs) + list(output_specs)
+        else:
+            # special case for ops like full, empty, which have no inputs. See further comments by `TENSOR_FACTORY_OPS`
+            specs = list(output_specs)
+
        for spec in specs:
            if spec is None:
                continue
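# --- Illustrative sketch, not part of this diff: the `input_specs is not None`
# --- check exists because factory ops build their OpSpec without input_specs
# --- (see factory_rule below), so input_specs stays at its default of None and
# --- only the output specs get validated for those ops.
#
#   factory_spec = OpSpec(output_specs=output_spec, redistribute_cost=[[0.0]])
#   assert factory_spec.input_specs is None  # takes the `else` branch above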
@@ -335,22 +347,84 @@ def randperm_rule(mesh, specs):
    return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])


-@register_rule(torch.ops.aten.full.default)
-def full_rule(mesh, specs):
-    raise NotImplementedError("Needs hardening, only tested on a few cases")
-    shape = specs[0]
-    # TODO: get the dtype
-    tensor_meta = _gen_tensor_meta(shape)
-    # TODO: I'm hard-coding this here, I'll probably need to do something else about this
-    placement = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
-    # placement = (Replicate(),) * mesh.ndim
-    input_placement = (Replicate(),) * mesh.ndim
-    spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
-    input_spec = DTensorSpec(mesh, input_placement, tensor_meta=tensor_meta)
-    # return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])
-    return OpStrategy(
-        [OpSpec(spec, input_specs=[input_spec], redistribute_cost=[[0.0]])]
-    )
+# We do a few special things for factory ops:
+# - use the factory rule below
+# - fake that they have input schemas so the solver doesn't freak out
+# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
+TENSOR_FACTORY_OPS = [
+    torch.ops.aten.zeros.default,
+    torch.ops.aten.ones.default,
+    torch.ops.aten.full.default,
+    torch.ops.aten.empty.memory_format,
+    torch.ops.aten.rand.default,
+    torch.ops.aten.randn.default,
+]
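# --- Illustrative sketch, not part of this diff: what converting to 'local
# --- tensor sizes' means. Assuming the solver picks Shard(0) on a mesh
# --- dimension of size 4, a zeros([8, 16]) call is rewritten during
# --- ApplySharding to allocate only the local shard:
#
#   global_size = [8, 16]
#   local_size = [8 // 4, 16]  # -> [2, 16]; a Replicate() placement leaves the size unchanged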
+
+
+@register_opschema_rule(TENSOR_FACTORY_OPS)
+def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
+    """
+    This is an auto-parallel-specific util that won't be upstreamed because of a UX mismatch.
+
+    In regular DTensor programs, a user has to either call `torch.full` to get a regular tensor, or
+    `torch.distributed.tensor.full` (with placements specified) to get a DTensor.
+
+    There is no point registering a strategy in DTensor for factories like 'full' since there is no way they
+    could be used by DTensor's dispatching logic. (Note: DTensor does provide strategies for similar ops like
+    'new_full' and 'full_like', the difference being there is an input tensor to trigger dispatch off of and to
+    use to direct the placement options.)
+
+    This util applies to any factory function that takes 'size' as the first argument,
+    and supports Replicate and Shard placements, all at zero cost.
+    """
+    assert isinstance(op_schema.args_schema[0], torch.Size)
+    shape = op_schema.args_schema[0]
+    x = torch.empty(shape, device="meta")
+    stride = x.stride()
+    dtype = torch.get_default_dtype()
+    if len(op_schema.args_schema) >= 3:
+        assert isinstance(op_schema.args_schema[2], torch.dtype)
+        dtype = op_schema.args_schema[2]
+    assert isinstance(dtype, torch.dtype), dtype
+
+    # TODO: ensure the solver knows that it is more expensive to Replicate factory functions than to shard them.
+    # For now, put Replicate last since this might encourage sharding. (Experimentally it seemed so, but this is
+    # definitely not a stable guarantee and we should fix it properly.)
+    single_mesh_dim_strategies = [[Shard(i)] for i in range(len(shape))] + [
+        [Replicate()]
+    ]
+
397+ """
398+ Expand the single_mesh_dim_strategies to full mesh dim strategies.
399+ see docs for `expand_to_full_mesh_op_strategy` in _tensor_ops.py in pytorch
400+ """
401+ all_mesh_dim_strategies = [single_mesh_dim_strategies ] * mesh .ndim
402+
403+ strategy_combs = list (itertools .product (* all_mesh_dim_strategies ))
404+
405+ all_strategies = []
406+ for strategy_comb in strategy_combs :
407+ spec_list = [DTensorSpec (mesh , specs ) for specs in zip (* strategy_comb )]
408+ output_specs = spec_list [0 ]
409+ output_specs .tensor_meta = TensorMeta (shape , stride , dtype )
410+
411+ if not is_tensor_shardable (shape , output_specs ):
412+ continue
413+
+        redistribute_cost = [
+            # TODO: there shouldn't actually be a row here, since there is no input to the op and the rows
+            # correspond to the inputs. However, the optimization code is not set up to tolerate input-less ops,
+            # so hack around it (see `walk_over_options` in autoparallel/optimize_sharding.py).
+            [0.0]
+            * len(strategy_combs)
+        ]
+
+        strategy = OpSpec(
+            output_specs=output_specs,
+            redistribute_cost=redistribute_cost,
+        )
+        all_strategies.append(strategy)
+    return OpStrategy(all_strategies)
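# --- Illustrative sketch, not part of this diff: how the expansion above enumerates
# --- candidate placements. For a rank-2 shape on a 2-D mesh, each mesh dim can pick
# --- Shard(0), Shard(1), or Replicate(), giving 3 ** mesh.ndim = 9 output placements
# --- (before is_tensor_shardable pruning). `_sketch_factory_placements` is a
# --- hypothetical helper reusing the Shard/Replicate/itertools names from above.

def _sketch_factory_placements(tensor_ndim: int, mesh_ndim: int):
    # one option list per mesh dim: shard along any tensor dim, or replicate
    single = [[Shard(i)] for i in range(tensor_ndim)] + [[Replicate()]]
    return [
        # factory ops have a single (output) tensor, so each option list has one entry
        tuple(option[0] for option in comb)
        for comb in itertools.product(*([single] * mesh_ndim))
    ]

# e.g. _sketch_factory_placements(2, 2) yields 9 tuples such as (Shard(0), Replicate())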


# ======================================