
Commit 3b9f914

DEBUG asynctp
stack-info: PR: #191, branch: IvanKobzarev/stack/8
1 parent 0ce9c0c commit 3b9f914

File tree: 4 files changed (+195, -11 lines)


autoparallel/api.py

Lines changed: 4 additions & 0 deletions
@@ -180,6 +180,7 @@ def __init__(
         enable_ac: bool = True,
         # None means 'auto'
         ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto",
+        enable_asynctp: bool = False,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -210,6 +211,8 @@ def __init__(
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB

+        self.enable_asynctp = enable_asynctp
+
         # NB: rest of the construction happens in __enter__
         self.active = False

@@ -236,6 +239,7 @@ def __enter__(self):
             self.mesh,
             rescale_grad_comm_cost_for_mp,
             repeated_subgraphs=self.kwargs.get("repeated_subgraphs", False),
+            enable_asynctp=self.enable_asynctp,
         )

         # makes sharding of params and gradients the same
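
Note: a minimal usage sketch of the new enable_asynctp constructor flag (hedged: the model, input_fn, mesh, and mp_policy objects are placeholders that must come from the caller; the call shape mirrors the examples/example_llama3.py change below):

from autoparallel.api import AutoParallel

def parallelize(model, input_fn, mesh, mp_policy):
    # enable_asynctp defaults to False; when set, __enter__ forwards it to ShardingOptimizer.
    with AutoParallel(
        model,
        input_fn,
        mesh,
        mp_policy,
        compile=True,
        repeated_subgraphs=True,
        enable_asynctp=True,
    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)
        return autop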

autoparallel/optimize_sharding.py

Lines changed: 134 additions & 3 deletions
@@ -105,6 +105,8 @@
 from .propagation_rules import _create_all_options
 from .utils import get_local_map_placement_option, get_placement_options

+aten = torch.ops.aten
+

 def _debug_node(node):
     def my_print(x):
@@ -128,7 +130,12 @@ def _get_next_name(name):

 class ShardingOptimizer:
     def __init__(
-        self, gm, mesh, rescale_grad_comm_cost_for_mp=1.0, repeated_subgraphs=False
+        self,
+        gm,
+        mesh,
+        rescale_grad_comm_cost_for_mp=1.0,
+        repeated_subgraphs=False,
+        enable_asynctp=False,
     ):
         self.gm = gm
         self.graph = gm.graph
@@ -139,11 +146,13 @@ def __init__(
         self.strats = self.build_sharding_metadata()

         self.cluster_links = {}
+        repeated_subgraphs = True
         if repeated_subgraphs:
             t = time.time()
             clusters = get_identical_regions(self.gm.graph, self.strats)
             print(f"Found {len(clusters)} clusters in {time.time() - t:.2f}s")
             self.create_cluster_links(clusters)
+        self.enable_asynctp = enable_asynctp

         # ds: Decision variables dictionary mapping (s_i, argi, ss, ii) -> ILP variable data
         # Each key represents a choice of input placement ii and output placement ss
@@ -556,11 +565,11 @@ def print_old(self):
         print(self.get_violated_constraints_log())

     def get_log(self, colored=False):
-
         from torch.fx.graph import _color_fns, _identity

         opt = {}
         nodes = list(self.graph.nodes)
+        log_shapes = False
         for x in self.res:
             opt.setdefault(nodes[x[0]], []).append(self.ds[x])

@@ -600,10 +609,28 @@ def get_log(self, colored=False):
                 code.insert(l_id, line)
                 l_id += 1
                 continue
+            log_extra = ""
+            if log_shapes:
+                if (
+                    isinstance(node, torch.fx.Node)
+                    and "val" in node.meta
+                    and isinstance(node.meta["val"], torch.Tensor)
+                ):
+                    log_extra += "("
+                    for arg in node.args:
+                        if (
+                            isinstance(arg, torch.fx.Node)
+                            and "val" in arg.meta
+                            and isinstance(arg.meta["val"], torch.Tensor)
+                        ):
+                            log_extra += str(list(arg.meta["val"].shape))
+                    log_extra += ") -> "
+                    log_extra += str(list(node.meta["val"].shape))
+                    log_extra += "\n"
             # LOL
             while not code[l_id].lstrip().startswith(repr(node)):
                 l_id += 1
-            code[l_id] += line
+            code[l_id] = log_extra + code[l_id] + line
             l_id += 1
         code = "\n".join(code)
         total_cost = sum(self.ds[x]["cost"] for x in self.res)
@@ -641,6 +668,11 @@ def get_solution(self, verbose=False):
         # add their costs
         for x in self.ds.values():
             opt_target[x["va"]] += x["cost"]
+
+        if self.enable_asynctp:
+            for va, cost in self.add_asynctp_scores().items():
+                opt_target[va] += cost
+
         self.prob += pulp.lpSum([va * cost for va, cost in opt_target.items()])

         # solver = pulp.HiGHS(msg=verbose)
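
Note: a standalone sketch of how the asyncTP deltas from add_asynctp_scores tilt the ILP objective, assuming only the pulp package (the toy variables and costs below are illustrative, not taken from the solver):

import pulp

# Two mutually exclusive placement choices for one node, as binary ILP variables.
prob = pulp.LpProblem("placement_toy", pulp.LpMinimize)
x_replicate = pulp.LpVariable("x_replicate", cat="Binary")
x_shard = pulp.LpVariable("x_shard", cat="Binary")
prob += x_replicate + x_shard == 1  # exactly one placement must be picked

# Base comm costs, plus a negative delta (-30% of the comm cost) on the option
# that would enable an asyncTP ag+mm fusion, mirroring opt_target above.
opt_target = {x_replicate: 100.0, x_shard: 90.0}
opt_target[x_replicate] += -0.3 * 100.0

prob += pulp.lpSum([va * cost for va, cost in opt_target.items()])
prob.solve(pulp.PULP_CBC_CMD(msg=False))
print(pulp.value(x_replicate), pulp.value(x_shard))  # 1.0 0.0 -> the fusible option wins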
@@ -885,6 +917,105 @@ def add_sharded_output_constraint(self, output_placements=None):
                 "them from the graph to avoid aliasing."
             )

+    def add_asynctp_scores(self):
+        # Encourage placements that enable asyncTP fusions by applying
+        # a discount of -X % of the comm cost to:
+        # 1. ag + mm: S(d) -> R, d < mm.ndim - 1
+        # 2. mm + rs: P -> S(d)
+        # TODO1: Filter out FSDP ag/rs that will not be asyncTPed.
+        # TODO2: With AsyncTP we should get perf wins from overlapping
+        # ((group_size - 1) / group_size) of the communication,
+        # minus the cost of the decomposition.
+        # For this we need to get group_size from the redistribution.
+        def _get_transformations(src_spec, tgt_spec):
+            # TODO: Use the real transform preparation.
+            # For now just check placements left to right.
+            src_pls = src_spec.placements
+            tgt_pls = tgt_spec.placements
+            transformations = []
+            for src_pl, tgt_pl in zip(src_pls, tgt_pls):
+                if src_pl == tgt_pl:
+                    continue
+                transformations.append((src_pl, tgt_pl))
+            return transformations
+
+        def _produces_asynctp_ag(src_spec, tgt_spec, mm_dim):
+            # Check that the last transition is S(dim) -> Replicate
+            transformations = _get_transformations(src_spec, tgt_spec)
+            if len(transformations) == 0:
+                return False
+            last_t = transformations[-1]
+            return (
+                last_t[1].is_replicate()
+                and last_t[0].is_shard()
+                and last_t[0].dim < mm_dim - 1
+            )
+
+        def _produces_asynctp_rs(src_spec, tgt_spec, mm_dim):
+            # Check that the last transition is P -> S(dim)
+            transformations = _get_transformations(src_spec, tgt_spec)
+            if len(transformations) == 0:
+                return False
+            last_t = transformations[-1]
+            return last_t[0].is_partial() and last_t[1].is_shard()
+
+        va_cost_delta = defaultdict(int)
+        strats = self.strats
+        for s_i, (node, s) in enumerate(strats.items()):
+            if not (node.op == "call_function" and node.target == aten.mm.default):
+                continue
+            mm_n = node
+            # Incentivize ag+mm:
+            # arg0 of the mm should be S(dim) -> R to get an all_gather before the mm.
+            a_n = node.args[0]
+            mm_sts = s.strategies
+            for mm_st_i, mm_st in enumerate(mm_sts):
+                a_sts = strats[a_n].strategies
+                mm_tgt_spec = mm_st.input_specs[0]
+                for a_st_i, a_st in enumerate(a_sts):
+                    a_src_spec = a_st.output_spec
+                    # TODO: Is adding a constraint to the arg enough, or do we need to
+                    # follow the arg's ancestors and find the first sharding change?
+                    if _produces_asynctp_ag(
+                        a_src_spec, mm_tgt_spec, mm_n.meta["val"].ndim
+                    ):
+                        # TODO: We want to calculate the cost of the specific AG, as it
+                        # will be pipelined; for now just use the redistribution cost.
+                        cost = mm_st.redistribute_cost[0][a_st_i]
+                        if cost == float("inf"):
+                            continue
+                        va = self.ds[(s_i, 0, mm_st_i, a_st_i)]["va"]
+                        va_cost_delta[va] += -0.3 * cost
+                # mm+rs
+                src_spec = mm_st.output_spec
+                if len(mm_n.users) == 0:
+                    continue
+                mm_user = next(iter(mm_n.users))
+                mm_user_s_i = self.node_map[mm_user]
+                mm_u_arg_mm_i = -1
+                for i, arg in enumerate(mm_user.args):
+                    if arg == mm_n:
+                        mm_u_arg_mm_i = i
+                assert mm_u_arg_mm_i != -1
+                mm_user_sts = strats[mm_user].strategies
+                for mm_u_st_i, mm_u_st in enumerate(mm_user_sts):
+                    if _produces_asynctp_rs(
+                        src_spec,
+                        mm_u_st.input_specs[mm_u_arg_mm_i],
+                        mm_n.meta["val"].ndim,
+                    ):
+                        # TODO: We want to calculate the cost of the specific RS, as it
+                        # will be pipelined; for now just use the redistribution cost.
+                        cost = mm_u_st.redistribute_cost[mm_u_arg_mm_i][mm_u_st_i]
+                        if cost == float("inf"):
+                            continue
+                        key = (mm_user_s_i, mm_u_arg_mm_i, mm_u_st_i, mm_st_i)
+                        va = self.ds[key]["va"]
+                        va_cost_delta[va] += -0.3 * cost
+
+        return va_cost_delta
+
     def validate(self):
         for node in self.graph.nodes:
             if node.op != "call_function":
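
Note: the helper predicates in add_asynctp_scores only inspect the last differing placement transition. A standalone sketch of that heuristic (assumes a recent PyTorch that exposes the DTensor placement types under torch.distributed.tensor; the function names here are illustrative):

from torch.distributed.tensor import Partial, Replicate, Shard


def last_transition(src_placements, tgt_placements):
    # Compare placements mesh-dim by mesh-dim, left to right, keeping only the changes.
    changes = [(s, t) for s, t in zip(src_placements, tgt_placements) if s != t]
    return changes[-1] if changes else None


def enables_ag_mm(src_placements, tgt_placements, mm_ndim):
    # ag + mm candidate: last transition is S(d) -> R with d before the last dim.
    t = last_transition(src_placements, tgt_placements)
    return (
        t is not None
        and t[0].is_shard()
        and t[1].is_replicate()
        and t[0].dim < mm_ndim - 1
    )


def enables_mm_rs(src_placements, tgt_placements):
    # mm + rs candidate: last transition is P -> S(d).
    t = last_transition(src_placements, tgt_placements)
    return t is not None and t[0].is_partial() and t[1].is_shard()


print(enables_ag_mm([Shard(0)], [Replicate()], mm_ndim=2))  # True
print(enables_mm_rs([Partial()], [Shard(0)]))               # True
print(enables_mm_rs([Replicate()], [Shard(0)]))             # False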

examples/example_llama3.py

Lines changed: 51 additions & 8 deletions
@@ -211,7 +211,13 @@ def add_tp_constraints(autop):

 # parallelize the model
 with AutoParallel(
-    model, input_fn, mesh, mp_policy, compile=True, repeated_subgraphs=True
+    model,
+    input_fn,
+    mesh,
+    mp_policy,
+    compile=True,
+    repeated_subgraphs=True,
+    enable_asynctp=enable_asynctp,
 ) as autop:
     autop.add_parameter_memory_constraint(low=None, high=None)

@@ -229,22 +235,59 @@ def add_tp_constraints(autop):
 if enable_manual_constraint and not use_1d_mesh:
     add_tp_constraints(autop)

-if enable_asynctp:
-    from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+enable_overlap_scheduling = True
+enable_overlap_scheduling_bucketing = True
+if enable_overlap_scheduling_bucketing:
+    assert (
+        enable_overlap_scheduling
+    ), "bucketing can not be used without overlap scheduling"
+enable_asynctp = True
+from autoparallel.asynctp import (
+    _micro_pipeline_tp_ag_mm_last_dim_enabled,
+    _micro_pipeline_tp_ag_transpose_mm_enabled,
+)
+
+_micro_pipeline_tp_ag_transpose_mm_enabled = True
+_micro_pipeline_tp_ag_mm_last_dim_enabled = True
+if (
+    enable_overlap_scheduling
+    or enable_overlap_scheduling_bucketing
+    or enable_asynctp
+):
+    torch._inductor.config.reorder_for_peak_memory = False
+    torch._inductor.config.reorder_for_compute_comm_overlap = False
+    torch._inductor.config.allow_buffer_reuse = False
+    torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
+        enable_overlap_scheduling_bucketing
+    )

-    enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
-    enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
-    torch._inductor.config._micro_pipeline_tp = False
-    from autoparallel.asynctp import micro_pipeline_tp_pass
+if enable_asynctp:
+    from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+    enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
+    enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
+    torch._inductor.config._micro_pipeline_tp = False
+    # Disable the inductor AsyncTP passes in favor of the Autoparallel fork of those passes.
+    # TODO: Switch to the inductor AsyncTP passes once all additions have landed.
+    from autoparallel.asynctp import micro_pipeline_tp_pass

     existing_post_grad_custom_post_pass = (
         torch._inductor.config.post_grad_custom_post_pass
     )
+    from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler

     def _pass(graph):
         if existing_post_grad_custom_post_pass is not None:
             existing_post_grad_custom_post_pass(graph)
-        micro_pipeline_tp_pass(graph)
+
+        collective_info = None
+        if enable_overlap_scheduling:
+            overlap_scheduler = OverlapScheduler(graph.owning_module)
+            overlap_scheduler.run()
+            collective_info = overlap_scheduler.collective_info
+
+        if enable_asynctp:
+            micro_pipeline_tp_pass(graph, collective_info)

     torch._inductor.config.post_grad_custom_post_pass = _pass
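
Note: the pass registration above deliberately chains any pre-existing post-grad pass instead of overwriting it. A minimal standalone sketch of just that pattern (the async-TP and overlap-scheduling rewrites themselves are elided):

import torch

# Keep whatever pass was registered before us, then run our own rewrites after it.
existing_post_grad_custom_post_pass = (
    torch._inductor.config.post_grad_custom_post_pass
)

def _chained_pass(graph):
    if existing_post_grad_custom_post_pass is not None:
        existing_post_grad_custom_post_pass(graph)
    # ... additional FX graph rewrites would run here (e.g. micro_pipeline_tp_pass) ...

torch._inductor.config.post_grad_custom_post_pass = _chained_pass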

mast/sweep.py

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
         "--model.name=llama3",
         "--compile.enable",
     ],
+    "llama3_FSDP_tp_async_tp_compile": llama3_2d_common_opts
+    + [
+        "--model.name=llama3",
+        "--compile.enable",
+        "--parallelism.enable_async_tensor_parallel",
+    ],
     "llama3_autop_2d_compile": llama3_2d_common_opts
     + [
         "--model.name=llama3_auto_parallel",
