|
105 | 105 | from .propagation_rules import _create_all_options |
106 | 106 | from .utils import get_local_map_placement_option, get_placement_options |
107 | 107 |
|
| 108 | +aten = torch.ops.aten |
| 109 | + |
108 | 110 |
|
109 | 111 | def _debug_node(node): |
110 | 112 | def my_print(x): |
@@ -557,11 +559,11 @@ def print_old(self): |
557 | 559 | print(self.get_violated_constraints_log()) |
558 | 560 |
|
559 | 561 | def get_log(self, colored=False): |
560 | | - |
561 | 562 | from torch.fx.graph import _color_fns, _identity |
562 | 563 |
|
563 | 564 | opt = {} |
564 | 565 | nodes = list(self.graph.nodes) |
| 566 | + log_shapes = False |
565 | 567 | for x in self.res: |
566 | 568 | opt.setdefault(nodes[x[0]], []).append(self.ds[x]) |
567 | 569 |
|
@@ -601,10 +603,28 @@ def get_log(self, colored=False): |
601 | 603 | code.insert(l_id, line) |
602 | 604 | l_id += 1 |
603 | 605 | continue |
| 606 | + log_extra = "" |
| 607 | + if log_shapes: |
| 608 | + if ( |
| 609 | + isinstance(node, torch.fx.Node) |
| 610 | + and "val" in node.meta |
| 611 | + and isinstance(node.meta["val"], torch.Tensor) |
| 612 | + ): |
| 613 | + log_extra += "(" |
| 614 | + for arg in node.args: |
| 615 | + if ( |
| 616 | + isinstance(arg, torch.fx.Node) |
| 617 | + and "val" in arg.meta |
| 618 | + and isinstance(arg.meta["val"], torch.Tensor) |
| 619 | + ): |
| 620 | + log_extra += str(list(arg.meta["val"].shape)) |
| 621 | + log_extra += ") -> " |
| 622 | + log_extra += str(list(node.meta["val"].shape)) |
| 623 | + log_extra += "\n" |
604 | 624 | # HACK: linearly scan forward until we find the generated code line for this node
605 | 625 | while not code[l_id].lstrip().startswith(repr(node)): |
606 | 626 | l_id += 1 |
607 | | - code[l_id] += line |
| 627 | + code[l_id] = log_extra + code[l_id] + line |
608 | 628 | l_id += 1 |
609 | 629 | code = "\n".join(code) |
610 | 630 | total_cost = sum(self.ds[x]["cost"] for x in self.res) |
@@ -642,6 +662,11 @@ def get_solution(self, verbose=False): |
642 | 662 | # add their costs |
643 | 663 | for x in self.ds.values(): |
644 | 664 | opt_target[x["va"]] += x["cost"] |
| 665 | + |
| 666 | + async_tp_enabled = True |
| 667 | + for va, cost in self.add_asynctp_scores().items(): |
| 668 | + opt_target[va] += cost |
| 669 | + |
645 | 670 | self.prob += pulp.lpSum([va * cost for va, cost in opt_target.items()]) |
646 | 671 |
|
647 | 672 | # solver = pulp.HiGHS(msg=verbose) |
@@ -886,6 +911,100 @@ def add_sharded_output_constraint(self, output_placements=None): |
886 | 911 | "them from the graph to avoid aliasing." |
887 | 912 | ) |
888 | 913 |
|
| 914 | + def add_asynctp_scores(self): |
| 915 | + # Encourage placements that enable asyncTP fusions: |
| 916 | + # -10% of comm_cost |
| 917 | + # 1. ag + mm: S(d) -> R, d <= mm.ndim - 2 |
| 918 | + # 2. mm + rs: P -> S(d) |
| 919 | + # TODO: Filter out FSDP ag/rs that will not be asyncTPed |
| 920 | + def _get_transformations(src_spec, tgt_spec): |
| 921 | + # TODO: Use real transform preparation |
| 922 | + # For now just checking left to right |
| 923 | + src_pls = src_spec.placements |
| 924 | + tgt_pls = tgt_spec.placements |
| 925 | + transformations = [] |
| 926 | + for src_pl, tgt_pl in zip(src_pls, tgt_pls): |
| 927 | + if src_pl == tgt_pl: |
| 928 | + continue |
| 929 | + transformations.append((src_pl, tgt_pl)) |
| 930 | + return transformations |
| 931 | + |
| 932 | + def _produces_asynctp_ag(src_spec, tgt_spec, mm_dim): |
| 933 | + # Check that the last transition will be S(dim) -> Replicate |
| 934 | + |
| 935 | + transformations = _get_transformations(src_spec, tgt_spec) |
| 936 | + if len(transformations) == 0: |
| 937 | + return False |
| 938 | + last_t = transformations[-1] |
| 939 | + return ( |
| 940 | + last_t[1].is_replicate() |
| 941 | + and last_t[0].is_shard() |
| 942 | + and last_t[0].dim < mm_dim - 1 |
| 943 | + ) |
| 944 | + |
| 945 | + def _produces_asynctp_rs(src_spec, tgt_spec): |
| 946 | + # Check that the last transition will be P -> S(dim) |
| 947 | + transformations = _get_transformations(src_spec, tgt_spec) |
| 948 | + if len(transformations) == 0: |
| 949 | + return False |
| 950 | + last_t = transformations[-1] |
| 951 | + return last_t[0].is_partial() and last_t[1].is_shard() |
| 952 | + |
| 953 | + va_cost_delta = defaultdict(int) |
| 954 | + strats = self.strats |
| 955 | + for s_i, (node, s) in enumerate(strats.items()): |
| 956 | + if not (node.op == "call_function" and node.target == aten.mm.default): |
| 957 | + continue |
| 958 | + mm_n = node |
| 959 | + # Incentivize ag+mm |
| 960 | + # ard0 of MM should be S(dim) -> R to have all_gather before mm |
| 961 | + a_n = node.args[0] |
| 962 | + mm_sts = s.strategies |
| 963 | + for mm_st_i, mm_st in enumerate(s.strategies): |
| 964 | + a_sts = strats[a_n].strategies |
| 965 | + mm_tgt_spec = mm_st.input_specs[0] |
| 966 | + mm_st_placements = mm_tgt_spec.placements |
| 967 | + for a_st_i, a_st in enumerate(a_sts): |
| 968 | + a_src_spec = a_st.output_spec |
| 969 | + a_st_placements = a_src_spec.placements |
| 970 | + # TODO: Is adding constraint to arg is enough or we need to follow the arg |
| 971 | + # ancestors and find the first sharding change? |
| 972 | + if _produces_asynctp_ag( |
| 973 | + a_src_spec, mm_tgt_spec, mm_n.meta["val"].ndim |
| 974 | + ): |
| 975 | + # TODO: We want to to calculate the cost of specific AG, as it will be pipelined, |
| 976 | + # for now using just redistribution cost |
| 977 | + cost = mm_st.redistribute_cost[0][a_st_i] |
| 978 | + if cost == float("inf"): |
| 979 | + continue |
| 980 | + va = self.ds[(s_i, 0, mm_st_i, a_st_i)]["va"] |
| 981 | + va_cost_delta[va] += -0.2 * cost |
| 982 | + # mm+rs |
| 983 | + src_spec = mm_st.output_spec |
| 984 | + if len(mm_n.users) == 0: |
| 985 | + continue |
| 986 | + mm_user = next(iter(mm_n.users)) |
| 987 | + mm_user_s_i = self.node_map[mm_user] |
| 988 | + mm_u_arg_mm_i = -1 |
| 989 | + for i, arg in enumerate(mm_user.args): |
| 990 | + if arg == mm_n: |
| 991 | + mm_u_arg_mm_i = i |
| 992 | + assert mm_u_arg_mm_i != -1 |
| 993 | + mm_user_sts = strats[mm_user].strategies |
| 994 | + for mm_u_st_i, mm_u_st in enumerate(mm_user_sts): |
| 995 | + if _produces_asynctp_rs(src_spec, mm_u_st.input_specs[mm_u_arg_mm_i]): |
| 996 | + # TODO: We want to to calculate the cost of specific RS, as it will be pipelined, |
| 997 | + # for now using just redistribution cost |
| 998 | + cost = mm_u_st.redistribute_cost[mm_u_arg_mm_i][mm_u_st_i] |
| 999 | + if cost == float("inf"): |
| 1000 | + continue |
| 1001 | + va = self.ds[(mm_user_s_i, mm_u_arg_mm_i, mm_st_i, mm_u_st_i)][ |
| 1002 | + "va" |
| 1003 | + ] |
| 1004 | + va_cost_delta[va] += -0.2 * cost |
| 1005 | + |
| 1006 | + return va_cost_delta |
| 1007 | + |
889 | 1008 | def validate(self): |
890 | 1009 | for node in self.graph.nodes: |
891 | 1010 | if node.op != "call_function": |
|
0 commit comments