 from .propagation_rules import _create_all_options
 from .utils import get_local_map_placement_option, get_placement_options

+aten = torch.ops.aten
+

 def _debug_node(node):
     def my_print(x):
@@ -128,7 +130,12 @@ def _get_next_name(name):

 class ShardingOptimizer:
     def __init__(
-        self, gm, mesh, rescale_grad_comm_cost_for_mp=1.0, repeated_subgraphs=False
+        self,
+        gm,
+        mesh,
+        rescale_grad_comm_cost_for_mp=1.0,
+        repeated_subgraphs=False,
+        enable_asynctp=False,
     ):
         self.gm = gm
         self.graph = gm.graph
@@ -139,11 +146,12 @@ def __init__(
         self.strats = self.build_sharding_metadata()

         self.cluster_links = {}
         if repeated_subgraphs:
             t = time.time()
             clusters = get_identical_regions(self.gm.graph, self.strats)
             print(f"Found {len(clusters)} clusters in {time.time() - t:.2f}s")
             self.create_cluster_links(clusters)
+        self.enable_asynctp = enable_asynctp

         # ds: Decision variables dictionary mapping (s_i, argi, ss, ii) -> ILP variable data
         # Each key represents a choice of input placement ii and output placement ss
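+        # e.g. ds[(3, 0, 1, 2)] (illustrative indices) is the ILP variable for
+        # node 3 choosing output placement strategy 1 while its arg 0 arrives
+        # with placement strategy 2.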
@@ -556,11 +565,11 @@ def print_old(self):
         print(self.get_violated_constraints_log())

     def get_log(self, colored=False):
-
         from torch.fx.graph import _color_fns, _identity

         opt = {}
         nodes = list(self.graph.nodes)
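+        # Debug toggle: set to True to annotate each logged op with the shapes
+        # of its tensor args and output.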
+        log_shapes = False
         for x in self.res:
             opt.setdefault(nodes[x[0]], []).append(self.ds[x])

@@ -600,10 +609,28 @@ def get_log(self, colored=False):
                 code.insert(l_id, line)
                 l_id += 1
                 continue
+            log_extra = ""
+            if log_shapes:
+                if (
+                    isinstance(node, torch.fx.Node)
+                    and "val" in node.meta
+                    and isinstance(node.meta["val"], torch.Tensor)
+                ):
+                    log_extra += "("
+                    for arg in node.args:
+                        if (
+                            isinstance(arg, torch.fx.Node)
+                            and "val" in arg.meta
+                            and isinstance(arg.meta["val"], torch.Tensor)
+                        ):
+                            log_extra += str(list(arg.meta["val"].shape))
+                    log_extra += ") -> "
+                    log_extra += str(list(node.meta["val"].shape))
+                    log_extra += "\n"
-            # LOL
+            # Advance to the generated source line that defines this node.
             while not code[l_id].lstrip().startswith(repr(node)):
                 l_id += 1
-            code[l_id] += line
+            code[l_id] = log_extra + code[l_id] + line
             l_id += 1
         code = "\n".join(code)
         total_cost = sum(self.ds[x]["cost"] for x in self.res)
@@ -641,6 +668,11 @@ def get_solution(self, verbose=False):
         # add their costs
         for x in self.ds.values():
             opt_target[x["va"]] += x["cost"]
+
+        if self.enable_asynctp:
+            for va, cost in self.add_asynctp_scores().items():
+                opt_target[va] += cost
+
         self.prob += pulp.lpSum([va * cost for va, cost in opt_target.items()])

         # solver = pulp.HiGHS(msg=verbose)
@@ -885,6 +917,105 @@ def add_sharded_output_constraint(self, output_placements=None):
             "them from the graph to avoid aliasing."
         )

+    def add_asynctp_scores(self):
+        # Encourage placements that enable asyncTP fusions by discounting
+        # 30% of the comm cost for:
+        # 1. ag + mm: S(d) -> R, d < mm.ndim - 1
+        # 2. mm + rs: P -> S(d)
+        # TODO1: Filter out FSDP ag/rs that will not be asyncTPed.
+        # TODO2: With asyncTP we should see perf wins from overlapping
+        # ((group_size - 1) / group_size) of the communication,
+        # minus the cost of decomposition.
+        # For this we need to get group_size from the redistribution.
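+        # Illustrative example: for a 2-D mm input, Shard(0) -> Replicate matches
+        # pattern 1 (dim 0 < ndim - 1 = 1), so the all_gather feeding the mm is an
+        # asyncTP fusion candidate; an mm output going Partial -> Shard(d) into its
+        # user matches pattern 2 (fusable reduce_scatter).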
+        def _get_transformations(src_spec, tgt_spec):
+            # TODO: Use real transform preparation.
+            # For now just checking placements left to right.
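+            # e.g. src placements (Shard(0), Partial()) vs tgt (Shard(0), Shard(1))
+            # yield [(Partial(), Shard(1))]: only pairs that differ are recorded.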
+            src_pls = src_spec.placements
+            tgt_pls = tgt_spec.placements
+            transformations = []
+            for src_pl, tgt_pl in zip(src_pls, tgt_pls):
+                if src_pl == tgt_pl:
+                    continue
+                transformations.append((src_pl, tgt_pl))
+            return transformations
+
+        def _produces_asynctp_ag(src_spec, tgt_spec, mm_dim):
+            # Check that the last transition is S(dim) -> Replicate.
+            transformations = _get_transformations(src_spec, tgt_spec)
+            if len(transformations) == 0:
+                return False
+            last_t = transformations[-1]
+            return (
+                last_t[1].is_replicate()
+                and last_t[0].is_shard()
+                and last_t[0].dim < mm_dim - 1
+            )
+
+        def _produces_asynctp_rs(src_spec, tgt_spec, mm_dim):
+            # Check that the last transition is P -> S(dim).
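+            # NOTE: mm_dim is accepted for symmetry with the ag check but is
+            # currently unused here.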
+            transformations = _get_transformations(src_spec, tgt_spec)
+            if len(transformations) == 0:
+                return False
+            last_t = transformations[-1]
+            return last_t[0].is_partial() and last_t[1].is_shard()
+
+        va_cost_delta = defaultdict(int)
+        strats = self.strats
+        for s_i, (node, s) in enumerate(strats.items()):
+            if not (node.op == "call_function" and node.target == aten.mm.default):
+                continue
+            mm_n = node
+            # Incentivize ag+mm:
+            # arg0 of the mm should be S(dim) -> R so an all_gather precedes the mm.
+            a_n = node.args[0]
+            mm_sts = s.strategies
+            for mm_st_i, mm_st in enumerate(mm_sts):
+                a_sts = strats[a_n].strategies
+                mm_tgt_spec = mm_st.input_specs[0]
+                for a_st_i, a_st in enumerate(a_sts):
+                    a_src_spec = a_st.output_spec
+                    # TODO: Is adding the constraint to the arg enough, or do we need
+                    # to follow the arg's ancestors and find the first sharding change?
+                    if _produces_asynctp_ag(
+                        a_src_spec, mm_tgt_spec, mm_n.meta["val"].ndim
+                    ):
+                        # TODO: We want to calculate the cost of the specific AG, as it
+                        # will be pipelined; for now we use the redistribution cost.
+                        cost = mm_st.redistribute_cost[0][a_st_i]
+                        if cost == float("inf"):
+                            continue
+                        va = self.ds[(s_i, 0, mm_st_i, a_st_i)]["va"]
+                        va_cost_delta[va] += -0.3 * cost
+                # mm+rs
+                src_spec = mm_st.output_spec
+                if len(mm_n.users) == 0:
+                    continue
+                mm_user = next(iter(mm_n.users))
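+                # NOTE: only the first user of the mm output is checked for the
+                # mm+rs pattern.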
+                mm_user_s_i = self.node_map[mm_user]
+                mm_u_arg_mm_i = -1
+                for i, arg in enumerate(mm_user.args):
+                    if arg == mm_n:
+                        mm_u_arg_mm_i = i
+                assert mm_u_arg_mm_i != -1
+                mm_user_sts = strats[mm_user].strategies
+                for mm_u_st_i, mm_u_st in enumerate(mm_user_sts):
+                    if _produces_asynctp_rs(
+                        src_spec,
+                        mm_u_st.input_specs[mm_u_arg_mm_i],
+                        mm_n.meta["val"].ndim,
+                    ):
+                        # TODO: We want to calculate the cost of the specific RS, as it
+                        # will be pipelined; for now we use the redistribution cost.
+                        # The mm arg arrives with strategy mm_st_i, so the
+                        # redistribution cost is indexed by it.
+                        cost = mm_u_st.redistribute_cost[mm_u_arg_mm_i][mm_st_i]
+                        if cost == float("inf"):
+                            continue
+                        key = (mm_user_s_i, mm_u_arg_mm_i, mm_u_st_i, mm_st_i)
+                        va = self.ds[key]["va"]
+                        va_cost_delta[va] += -0.3 * cost
+
+        return va_cost_delta
+
     def validate(self):
         for node in self.graph.nodes:
             if node.op != "call_function":