     propagate_shape_and_sharding,
     register_op_strategy_map,
 )
-from torch.distributed.tensor._ops.utils import generate_redistribute_costs
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+)
 from torch.distributed.tensor.placement_types import Replicate, Shard
 
 # TODO: move this to PyTorch
@@ -60,11 +63,15 @@ def wrapper(impl):
 _op_partial_rules = {}
 
 
-def register_opschema_rule(op):
+def register_opschema_rule(ops):
     global _op_partial_rules
 
     def wrapper(impl):
-        _op_partial_rules[op] = impl
+        if isinstance(ops, list):
+            for op in ops:
+                _op_partial_rules[op] = impl
+        else:
+            _op_partial_rules[ops] = impl
         return impl
 
     return wrapper
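
# Illustrative sketch (not part of the diff): with this change, a single partial rule
# can be registered for several op overloads at once by passing a list. The factory
# rule added later in this diff uses exactly this with TENSOR_FACTORY_OPS; the op list
# and rule name below are only for illustration.
@register_opschema_rule([torch.ops.aten.zeros.default, torch.ops.aten.ones.default])
def shared_factory_rule(mesh, op_schema):
    ...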
@@ -97,7 +104,11 @@ def remove_invalid_configs(out_strat, mesh):
         output_specs = strategy.output_specs
         if isinstance(output_specs, DTensorSpec):
             output_specs = [output_specs]
-        specs = list(strategy.input_specs) + list(output_specs)
+        if strategy.input_specs is not None:
+            specs = list(strategy.input_specs) + list(output_specs)
+        else:
+            specs = list(output_specs)
+
         for spec in specs:
             if spec is None:
                 continue
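
# Illustrative sketch (not part of the diff): the new None check matters because the
# factory-op strategies introduced below carry only output specs. The collection logic
# in isolation, with plain strings standing in for DTensorSpec:
def collect_specs(input_specs, output_specs):
    if input_specs is not None:
        return list(input_specs) + list(output_specs)
    return list(output_specs)

print(collect_specs(None, ["Shard(0)"]))            # ['Shard(0)']
print(collect_specs(["Replicate"], ["Shard(0)"]))   # ['Replicate', 'Shard(0)']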
@@ -335,22 +346,82 @@ def randperm_rule(mesh, specs):
     return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])
 
 
-@register_rule(torch.ops.aten.full.default)
-def full_rule(mesh, specs):
-    raise NotImplementedError("Needs hardening, only tested on a few cases")
-    shape = specs[0]
-    # TODO: get the dtype
-    tensor_meta = _gen_tensor_meta(shape)
-    # TODO: I'm hard-coding this here, I'll probably need to do something else about this
-    placement = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
-    # placement = (Replicate(),) * mesh.ndim
-    input_placement = (Replicate(),) * mesh.ndim
-    spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
-    input_spec = DTensorSpec(mesh, input_placement, tensor_meta=tensor_meta)
-    # return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])
-    return OpStrategy(
-        [OpSpec(spec, input_specs=[input_spec], redistribute_cost=[[0.0]])]
-    )
+# We do a few special things for factory ops:
+# - use the factory rule below
+# - fake that they have input schemas so the solver doesn't choke on input-less ops
+# - convert their sizes to 'local tensor sizes' (divide by the mesh dim size) during ApplySharding
+TENSOR_FACTORY_OPS = [
+    torch.ops.aten.zeros.default,
+    torch.ops.aten.ones.default,
+    torch.ops.aten.full.default,
+    torch.ops.aten.empty.memory_format,
+    torch.ops.aten.rand.default,
+    torch.ops.aten.randn.default,
+]
+
+
+@register_opschema_rule(TENSOR_FACTORY_OPS)
+def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
+    """
+    This is an auto-parallel-specific util that won't be upstreamed because of a UX mismatch.
+
+    In regular DTensor programs, a user has to either call `torch.full` to get a regular tensor, or
+    `torch.distributed.tensor.full` (with placements specified) to get a DTensor.
+
+    There is no point registering a strategy in DTensor for factories like 'full', since there is no way they
+    could be used by DTensor's dispatching logic. (Note: DTensor does provide strategies for similar ops like
+    'new_full' and 'full_like'; the difference is that those have an input tensor to trigger dispatch off of
+    and to use to direct the placement options.)
+
+    This util applies to any factory function that takes 'size' as its first argument,
+    and supports Replicate and Shard placements, all at zero cost.
+    """
+    shape = op_schema.args_schema[0]
+    x = torch.empty(shape, device="meta")
+    stride = x.stride()
+    dtype = torch.get_default_dtype()
+    if len(op_schema.args_schema) >= 3:
+        # TODO: didn't really verify this
+        dtype = op_schema.args_schema[2]
+    assert isinstance(dtype, torch.dtype), dtype
+
+    # TODO: ensure the solver knows that it is more expensive to Replicate factory functions than to shard them;
+    # for now, put Replicate last, since this might encourage sharding.
+    single_mesh_dim_strategies = [[Shard(i)] for i in range(len(shape))] + [
+        [Replicate()]
+    ]
+
+    # Expand single_mesh_dim_strategies to full mesh-dim strategies;
+    # see the docs for `expand_to_full_mesh_op_strategy` in _tensor_ops.py in PyTorch.
+    all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim
+
+    strategy_combs = list(itertools.product(*all_mesh_dim_strategies))
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)]
+        output_specs = spec_list[0]
+        output_specs.tensor_meta = TensorMeta(shape, stride, dtype)
+
+        if not is_tensor_shardable(shape, output_specs):
+            continue
+
+        redistribute_cost = [
+            # TODO: there shouldn't actually be a row here, since there is no input to the op and the rows
+            # correspond to the inputs. However, the optimization code is not set up to tolerate input-less
+            # ops, so hack around it
+            # (see "/data/users/whc/autoparallel/autoparallel/optimize_sharding.py", line 226, in walk_over_options)
+            [0.0]
+            * len(strategy_combs)
+        ]
+
+        strategy = OpSpec(
+            output_specs=output_specs,
+            redistribute_cost=redistribute_cost,
+        )
+        all_strategies.append(strategy)
+    return OpStrategy(all_strategies)
 
 
 # ======================================
@@ -603,18 +674,19 @@ def index_put_rule(mesh, op_schema):
         t_strats = [DTensorSpec(mesh, placements=ispec.placements)]
         s = OpSpec(output_specs=ospec, input_specs=[ispec] + idxs_strats + t_strats)
 
-        redistribute_costs = (
-            [generate_redistribute_costs(specs[0], ospec)]
-            + [
-                generate_redistribute_costs(kk, idxs_strat)
-                for kk, idxs_strat in zip(kspc, idxs_strats)
-            ]
-            + [generate_redistribute_costs(specs[2], t_strats[0])]
-        )
-        s.redistribute_cost = redistribute_costs
-        res.append(s)
-    out_strat = OpStrategy(res)
-    return out_strat
+
+        # redistribute_costs = (
+        #     [generate_redistribute_costs(specs[0], ospec)]
+        #     + [
+        #         generate_redistribute_costs(kk, idxs_strat)
+        #         for kk, idxs_strat in zip(kspc, idxs_strats)
+        #     ]
+        #     + [generate_redistribute_costs(specs[2], t_strats[0])]
+        # )
+        # s.redistribute_cost = redistribute_costs
+        # res.append(s)
+    # out_strat = OpStrategy(res)
+    # return out_strat
 
 
 @register_opschema_rule(torch.ops.aten._scaled_dot_product_efficient_attention.default)