Commit a6dd593

Pass to split prefetch fsdp graph
stack-info: PR: #201, branch: IvanKobzarev/stack/9
1 parent f1887eb commit a6dd593

File tree

2 files changed: +60 -0 lines changed

autoparallel/passes.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import dataclasses
+
+import torch
+import torch.utils._pytree as pytree
+from torch._functorch._aot_autograd.descriptors import AOTOutput
+from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
+
+
+@dataclasses.dataclass(frozen=True)
+class PrefetchOutput(AOTOutput):
+    pass
+
+
+def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
+    g_ins = g.find_nodes(op="placeholder")
+    prefetch_g_outs_map = {}
+
+    for g_in in g_ins:
+        n = g_in
+        while True:
+            if len(n.users) != 1:
+                break
+            user = next(iter(n.users))
+            if len(user.all_input_nodes) > 1:
+                break
+            n = user
+        prefetch_g_outs_map[g_in] = n
+
+    prefetch_g_outs = list(prefetch_g_outs_map.values())
+    prefetch_g_outs_descs: list[AOTOutput] = [
+        PrefetchOutput() for _ in range(len(prefetch_g_outs))
+    ]
+
+    prefetch_g = _extract_graph_with_inputs_outputs(
+        g,
+        g_ins,
+        prefetch_g_outs,
+        prefetch_g_outs_descs,
+    )
+
+    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
+    g_outs_descs = pytree.arg_tree_leaves(
+        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
+    )
+    main_g = _extract_graph_with_inputs_outputs(
+        g,
+        prefetch_g_outs,
+        g_outs,
+        g_outs_descs,
+    )
+    return prefetch_g, main_g
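
For context on what the new pass produces: starting from every graph input, it follows the chain of nodes that each have exactly one user and a single graph-node input (in the FSDP graph this is the all-gather/wait/cast chain that materializes a parameter) and cuts there, so prefetch_g computes those chains and main_g consumes their results as its inputs. The snippet below is a minimal, hypothetical sketch of the split on a toy symbolically traced module rather than on an AutoParallel graph; the Toy module and the comments about which nodes land where are illustrative assumptions, not part of the commit.

# Hypothetical sketch (not part of this commit): apply the pass to a toy
# traced graph instead of the AutoParallel FSDP graph.
import torch
from autoparallel.passes import split_fsdp_prefetch


class Toy(torch.nn.Module):
    def forward(self, w, x):
        # Single-user chain hanging off the `w` placeholder; in the real graph
        # this stands in for the all_gather/wait/cast chain of a parameter.
        w = w.to(torch.bfloat16)
        return x @ w


gm = torch.fx.symbolic_trace(Toy())
prefetch_g, main_g = split_fsdp_prefetch(gm.graph)

# Expected split for this toy graph:
#   prefetch_g: (w, x) -> (w.to(bfloat16), x)   # chain end per input
#   main_g:     (w_cast, x) -> x @ w_cast       # consumes the chain ends
print(prefetch_g)
print(main_g)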

examples/example_llama3.py

Lines changed: 9 additions & 0 deletions
@@ -213,6 +213,7 @@ def add_tp_constraints(autop):
 with AutoParallel(
     model, input_fn, mesh, mp_policy, compile=True, repeated_subgraphs=True
 ) as autop:
+
     autop.add_parameter_memory_constraint(low=None, high=None)
 
     x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
@@ -253,6 +254,14 @@ def _pass(graph):
     print(f"Took {time.time() - t:.2f} s")
     parallel_mod = autop.apply_placement(sharding_placement)
 
+    test_split_fsdp_prefetch = True
+    if test_split_fsdp_prefetch:
+        gm = autop.parallel_gm
+        g = gm.graph
+        from autoparallel.passes import split_fsdp_prefetch
+
+        prefetch_g, main_g = split_fsdp_prefetch(g)
+
     # run weight init on our sharded DTensor params
     parallel_mod.to_empty(device="cuda")
     parallel_mod.init_weights()
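
The example only builds the two graphs; the commit does not execute them. Because main_g's placeholders are exactly prefetch_g's outputs, in the same order, chaining the two reproduces the original graph's outputs. Below is a hedged sketch of that property using plain fx GraphModules; run_split is a hypothetical helper, not an AutoParallel API.

# Hypothetical helper (not in this commit): run prefetch_g, then feed its
# outputs into main_g, whose placeholders are exactly those values.
import torch


def run_split(
    root: torch.nn.Module,
    prefetch_g: torch.fx.Graph,
    main_g: torch.fx.Graph,
    *inputs,
):
    prefetch_mod = torch.fx.GraphModule(root, prefetch_g)
    main_mod = torch.fx.GraphModule(root, main_g)
    prefetched = prefetch_mod(*inputs)  # ends of the per-input chains
    return main_mod(*prefetched)        # original graph outputs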

0 commit comments