
Commit 5d54740

[asynctp] Add knobs to enable async TP; add solver constraints to incentivize async TP
1 parent cd27579 commit 5d54740

File tree

  autoparallel/optimize_sharding.py
  examples/example_llama3.py
  mast/sweep.py

3 files changed: +136 -2 lines changed

autoparallel/optimize_sharding.py

Lines changed: 121 additions & 2 deletions
@@ -105,6 +105,8 @@
 from .propagation_rules import _create_all_options
 from .utils import get_local_map_placement_option, get_placement_options
 
+aten = torch.ops.aten
+
 
 def _debug_node(node):
     def my_print(x):
@@ -557,11 +559,11 @@ def print_old(self):
         print(self.get_violated_constraints_log())
 
     def get_log(self, colored=False):
-
         from torch.fx.graph import _color_fns, _identity
 
         opt = {}
         nodes = list(self.graph.nodes)
+        log_shapes = False
         for x in self.res:
             opt.setdefault(nodes[x[0]], []).append(self.ds[x])
 
@@ -601,10 +603,28 @@ def get_log(self, colored=False):
                 code.insert(l_id, line)
                 l_id += 1
                 continue
+            log_extra = ""
+            if log_shapes:
+                if (
+                    isinstance(node, torch.fx.Node)
+                    and "val" in node.meta
+                    and isinstance(node.meta["val"], torch.Tensor)
+                ):
+                    log_extra += "("
+                    for arg in node.args:
+                        if (
+                            isinstance(arg, torch.fx.Node)
+                            and "val" in arg.meta
+                            and isinstance(arg.meta["val"], torch.Tensor)
+                        ):
+                            log_extra += str(list(arg.meta["val"].shape))
+                    log_extra += ") -> "
+                    log_extra += str(list(node.meta["val"].shape))
+                    log_extra += "\n"
             # LOL
             while not code[l_id].lstrip().startswith(repr(node)):
                 l_id += 1
-            code[l_id] += line
+            code[l_id] = log_extra + code[l_id] + line
             l_id += 1
         code = "\n".join(code)
         total_cost = sum(self.ds[x]["cost"] for x in self.res)
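
The shape annotations above are read from each FX node's meta["val"] entry, which holds the (fake) tensor produced by that node. A minimal standalone sketch of the same idea, on a hypothetical two-matmul function traced with make_fx (which populates meta["val"]); the function f and its shapes are made up for illustration:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b, c):
    return torch.mm(torch.mm(a, b), c)

gm = make_fx(f, tracing_mode="fake")(
    torch.randn(4, 8), torch.randn(8, 16), torch.randn(16, 2)
)
for node in gm.graph.nodes:
    val = node.meta.get("val")
    if not isinstance(val, torch.Tensor):
        continue
    # Same "(input shapes) -> output shape" string the diff builds into log_extra.
    in_shapes = [
        str(list(a.meta["val"].shape))
        for a in node.args
        if isinstance(a, torch.fx.Node)
        and isinstance(a.meta.get("val"), torch.Tensor)
    ]
    print(f"({', '.join(in_shapes)}) -> {list(val.shape)}")
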
@@ -642,6 +662,11 @@ def get_solution(self, verbose=False):
         # add their costs
         for x in self.ds.values():
             opt_target[x["va"]] += x["cost"]
+
+        async_tp_enabled = True
+        for va, cost in self.add_asynctp_scores().items():
+            opt_target[va] += cost
+
         self.prob += pulp.lpSum([va * cost for va, cost in opt_target.items()])
 
         # solver = pulp.HiGHS(msg=verbose)
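
The hunk above keeps the existing per-variable costs in opt_target and then adds the deltas returned by add_asynctp_scores, so asyncTP-friendly placements simply get cheaper in the objective rather than being forced. A toy PuLP sketch of that pattern (hypothetical variables and costs, not the solver's real model):

from collections import defaultdict

import pulp

# One node, two mutually exclusive strategies: keep the input replicated,
# or shard it and all_gather right before the mm.
prob = pulp.LpProblem("toy_sharding", pulp.LpMinimize)
va = {s: pulp.LpVariable(s, cat="Binary") for s in ("replicate", "shard_ag")}
prob += pulp.lpSum(va.values()) == 1  # pick exactly one strategy

base_cost = {"replicate": 100.0, "shard_ag": 110.0}
opt_target = defaultdict(float)
for s, v in va.items():
    opt_target[v] += base_cost[s]

# AsyncTP incentive: discount the strategy whose all_gather could be pipelined
# into the matmul (-20% of its cost, mirroring the -0.2 factor in the diff).
opt_target[va["shard_ag"]] += -0.2 * base_cost["shard_ag"]

prob += pulp.lpSum([v * c for v, c in opt_target.items()])
prob.solve(pulp.PULP_CBC_CMD(msg=False))
print({s: v.value() for s, v in va.items()})  # shard_ag wins: 110 - 22 < 100
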
@@ -886,6 +911,100 @@ def add_sharded_output_constraint(self, output_placements=None):
                 "them from the graph to avoid aliasing."
             )
 
+    def add_asynctp_scores(self):
+        # Encourage placements that enable asyncTP fusions:
+        # -20% of comm_cost
+        # 1. ag + mm: S(d) -> R, d <= mm.ndim - 2
+        # 2. mm + rs: P -> S(d)
+        # TODO: Filter out FSDP ag/rs that will not be asyncTPed
+        def _get_transformations(src_spec, tgt_spec):
+            # TODO: Use real transform preparation
+            # For now just checking left to right
+            src_pls = src_spec.placements
+            tgt_pls = tgt_spec.placements
+            transformations = []
+            for src_pl, tgt_pl in zip(src_pls, tgt_pls):
+                if src_pl == tgt_pl:
+                    continue
+                transformations.append((src_pl, tgt_pl))
+            return transformations
+
+        def _produces_asynctp_ag(src_spec, tgt_spec, mm_dim):
+            # Check that the last transition will be S(dim) -> Replicate
+
+            transformations = _get_transformations(src_spec, tgt_spec)
+            if len(transformations) == 0:
+                return False
+            last_t = transformations[-1]
+            return (
+                last_t[1].is_replicate()
+                and last_t[0].is_shard()
+                and last_t[0].dim < mm_dim - 1
+            )
+
+        def _produces_asynctp_rs(src_spec, tgt_spec):
+            # Check that the last transition will be P -> S(dim)
+            transformations = _get_transformations(src_spec, tgt_spec)
+            if len(transformations) == 0:
+                return False
+            last_t = transformations[-1]
+            return last_t[0].is_partial() and last_t[1].is_shard()
+
+        va_cost_delta = defaultdict(int)
+        strats = self.strats
+        for s_i, (node, s) in enumerate(strats.items()):
+            if not (node.op == "call_function" and node.target == aten.mm.default):
+                continue
+            mm_n = node
+            # Incentivize ag+mm
+            # arg0 of the mm should be S(dim) -> R to have an all_gather before the mm
+            a_n = node.args[0]
+            mm_sts = s.strategies
+            for mm_st_i, mm_st in enumerate(s.strategies):
+                a_sts = strats[a_n].strategies
+                mm_tgt_spec = mm_st.input_specs[0]
+                mm_st_placements = mm_tgt_spec.placements
+                for a_st_i, a_st in enumerate(a_sts):
+                    a_src_spec = a_st.output_spec
+                    a_st_placements = a_src_spec.placements
+                    # TODO: Is adding a constraint to the arg enough, or do we need to
+                    # follow the arg's ancestors and find the first sharding change?
+                    if _produces_asynctp_ag(
+                        a_src_spec, mm_tgt_spec, mm_n.meta["val"].ndim
+                    ):
+                        # TODO: We want to calculate the cost of the specific AG, as it will
+                        # be pipelined; for now just using the redistribution cost
+                        cost = mm_st.redistribute_cost[0][a_st_i]
+                        if cost == float("inf"):
+                            continue
+                        va = self.ds[(s_i, 0, mm_st_i, a_st_i)]["va"]
+                        va_cost_delta[va] += -0.2 * cost
+                # mm+rs
+                src_spec = mm_st.output_spec
+                if len(mm_n.users) == 0:
+                    continue
+                mm_user = next(iter(mm_n.users))
+                mm_user_s_i = self.node_map[mm_user]
+                mm_u_arg_mm_i = -1
+                for i, arg in enumerate(mm_user.args):
+                    if arg == mm_n:
+                        mm_u_arg_mm_i = i
+                assert mm_u_arg_mm_i != -1
+                mm_user_sts = strats[mm_user].strategies
+                for mm_u_st_i, mm_u_st in enumerate(mm_user_sts):
+                    if _produces_asynctp_rs(src_spec, mm_u_st.input_specs[mm_u_arg_mm_i]):
+                        # TODO: We want to calculate the cost of the specific RS, as it will
+                        # be pipelined; for now just using the redistribution cost
+                        cost = mm_u_st.redistribute_cost[mm_u_arg_mm_i][mm_u_st_i]
+                        if cost == float("inf"):
+                            continue
+                        va = self.ds[(mm_user_s_i, mm_u_arg_mm_i, mm_st_i, mm_u_st_i)][
+                            "va"
+                        ]
+                        va_cost_delta[va] += -0.2 * cost
+
+        return va_cost_delta
+
     def validate(self):
         for node in self.graph.nodes:
             if node.op != "call_function":
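
Both helper predicates look only at the last placement pair that differs between the source and target spec: a trailing Shard(d) -> Replicate (with d below the mm's last dim) is the all_gather the micro-pipeline pass can fuse with the matmul, and a trailing Partial -> Shard(d) is the fusible reduce_scatter. A small sketch of those two checks on bare placements, assuming a recent PyTorch where Shard/Replicate/Partial are importable from torch.distributed.tensor:

from torch.distributed.tensor import Partial, Replicate, Shard

def last_transition(src_placements, tgt_placements):
    # Mirrors _get_transformations: walk mesh dims left to right, keep the changes.
    changes = [(s, t) for s, t in zip(src_placements, tgt_placements) if s != t]
    return changes[-1] if changes else None

def looks_like_ag_mm(src_placements, tgt_placements, mm_ndim):
    t = last_transition(src_placements, tgt_placements)
    return (
        t is not None
        and t[0].is_shard()
        and t[1].is_replicate()
        and t[0].dim < mm_ndim - 1  # shard dim must not be the mm's last dim
    )

def looks_like_mm_rs(src_placements, tgt_placements):
    t = last_transition(src_placements, tgt_placements)
    return t is not None and t[0].is_partial() and t[1].is_shard()

# mm arg redistributed S(0) -> R on the tp mesh dim: candidate for ag+mm fusion.
print(looks_like_ag_mm((Shard(0),), (Replicate(),), mm_ndim=2))  # True
# mm output redistributed P -> S(1): candidate for mm+rs fusion.
print(looks_like_mm_rs((Partial(),), (Shard(1),)))  # True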

examples/example_llama3.py

Lines changed: 9 additions & 0 deletions
@@ -191,6 +191,15 @@ def add_tp_constraints(autop):
 if enable_manual_constraint and not use_1d_mesh:
     add_tp_constraints(autop)
 
+enable_async_tp = False
+if enable_async_tp:
+    from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+    enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
+    enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
+    torch._inductor.config._micro_pipeline_tp = True
+    torch._inductor.config.reorder_for_compute_comm_overlap = False
+
 t = time.time()
 sharding_placement = autop.optimize_placement()
 print(f"Took {time.time() - t:.2f} s")

mast/sweep.py

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
         "--model.name=llama3",
         "--compile.enable",
     ],
+    "llama3_FSDP_tp_async_tp_compile": llama3_2d_common_opts
+    + [
+        "--model.name=llama3",
+        "--compile.enable",
+        "--parallelism.enable_async_tensor_parallel",
+    ],
     "llama3_autop_2d_compile": llama3_2d_common_opts
     + [
         "--model.name=llama3_auto_parallel",
