
Commit c098c92

Pass to split all_gather prologue and reduce_scatter epilogue from fsdp graph
stack-info: PR: #201, branch: IvanKobzarev/stack/9
1 parent f1887eb commit c098c92

2 files changed: 155 additions & 0 deletions

autoparallel/pipeline/passes.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
from contextlib import contextmanager
from functools import partial
from typing import Any

import torch
import torch.fx.node
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
from torch._inductor.fx_passes.bucketing import (
    is_all_gather_into_tensor,
    is_reduce_scatter_tensor,
)


@contextmanager
def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
    # Temporarily remove the given ops from FX's set of side-effectful
    # functions so that graph extraction can dead-code-eliminate them.
    original_val = torch.fx.node._side_effectful_functions.copy()
    try:
        torch.fx.node._side_effectful_functions -= exclude_vals
        yield
    finally:
        torch.fx.node._side_effectful_functions.clear()
        torch.fx.node._side_effectful_functions.update(original_val)


exclude_wait_from_fx_side_effectful = partial(
    exclude_from_fx_side_effectful,
    {
        torch.ops._c10d_functional.wait_tensor,
        torch.ops._c10d_functional.wait_tensor.default,
    },
)


@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
    pass


@dataclasses.dataclass(frozen=True)
class EpilogueInput(AOTOutput):
    pass


def split_fsdp_prefetch(
    g: torch.fx.Graph, stop_at_all_gather: bool = True
) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    """Split `g` into an all_gather prologue (prefetch) graph and a main graph.

    For every placeholder, walk its chain of single-user, single-input nodes
    forward; when `stop_at_all_gather` is set, stop after the wait_tensor of
    the first all_gather_into_tensor encountered and use that node as the
    split boundary.
    """
    g_ins = g.find_nodes(op="placeholder")
    prefetch_g_outs_map = []

    for g_in in g_ins:
        n = g_in
        has_ag = False
        while True:
            if len(n.users) != 1:
                break
            user = next(iter(n.users))
            if len(user.all_input_nodes) > 1:
                break
            n = user
            if stop_at_all_gather and is_all_gather_into_tensor(n):
                has_ag = True
                w_n = next(iter(n.users))
                n = w_n
                break
        if stop_at_all_gather and not has_ag:
            # No all_gather on this input's chain: pass the input straight through.
            prefetch_g_outs_map.append(g_in)
        else:
            prefetch_g_outs_map.append(n)

    prefetch_g_outs = prefetch_g_outs_map
    prefetch_g_outs_descs: list[AOTOutput] = [
        PrefetchOutput() for _ in range(len(prefetch_g_outs))
    ]
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )
    with exclude_wait_from_fx_side_effectful():
        prefetch_g = _extract_graph_with_inputs_outputs(
            g,
            g_ins,
            prefetch_g_outs,
            prefetch_g_outs_descs,
        )

        main_g = _extract_graph_with_inputs_outputs(
            g,
            prefetch_g_outs,
            g_outs,
            g_outs_descs,
        )
    return prefetch_g, main_g


def split_fsdp_reduce_scatters_epilogue(
    g: torch.fx.Graph,
) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    """Split `g` into a main graph and a reduce_scatter epilogue graph.

    For every output, walk its chain of single-input, single-user nodes
    backwards; if the walk steps over a reduce_scatter_tensor, the node
    feeding that reduce_scatter becomes the split boundary.
    """
    g_ins = g.find_nodes(op="placeholder")
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )

    g_outs_map = []
    for g_out in g_outs:
        n = g_out
        has_rs = False
        while n is not None:
            if len(n.all_input_nodes) != 1:
                break
            n_in = n.all_input_nodes[0]
            if len(n_in.users) > 1:
                break
            prev_n = n
            n = n_in
            if is_reduce_scatter_tensor(prev_n):
                has_rs = True
                break
        if has_rs:
            g_outs_map.append(n)
        else:
            # No reduce_scatter on this output's chain: pass the output straight through.
            g_outs_map.append(g_out)

    epi_g_ins = [n for n in g_outs_map if n is not None]
    epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]
    main_g = _extract_graph_with_inputs_outputs(
        g,
        g_ins,
        epi_g_ins,
        epi_g_ins_descs,
    )
    epi_g = _extract_graph_with_inputs_outputs(
        g,
        epi_g_ins,
        g_outs,
        g_outs_descs,
    )

    return main_g, epi_g
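
As a usage illustration (not part of this commit), the two graphs returned by split_fsdp_prefetch can be wrapped back into fx.GraphModules and run back to back: by construction, the prefetch graph's outputs line up one-to-one with the main graph's placeholders. A minimal sketch, assuming `gm` is the traced torch.fx.GraphModule whose graph is being split and `run_split` is a hypothetical helper:

import torch

from autoparallel.pipeline.passes import split_fsdp_prefetch

# Assumed setup (not from this commit): `gm` is the traced torch.fx.GraphModule.
prefetch_g, main_g = split_fsdp_prefetch(gm.graph)
prefetch_gm = torch.fx.GraphModule(gm, prefetch_g)  # shares gm's attributes/params
main_gm = torch.fx.GraphModule(gm, main_g)


def run_split(*flat_args):
    # The prologue issues the FSDP all_gathers and their wait_tensors;
    # its outputs feed the main graph's placeholders in order.
    prefetched = prefetch_gm(*flat_args)
    return main_gm(*prefetched)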

examples/example_llama3.py

Lines changed: 8 additions & 0 deletions
@@ -253,6 +253,14 @@ def _pass(graph):
     print(f"Took {time.time() - t:.2f} s")
     parallel_mod = autop.apply_placement(sharding_placement)
 
+test_split_fsdp_prefetch = True
+if test_split_fsdp_prefetch:
+    gm = autop.parallel_gm
+    g = gm.graph
+    from autoparallel.pipeline.passes import split_fsdp_prefetch
+
+    prefetch_g, main_g = split_fsdp_prefetch(g)
+
 # run weight init on our sharded DTensor params
 parallel_mod.to_empty(device="cuda")
 parallel_mod.init_weights()
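
A similar toggle (hypothetical, not included in this diff) could exercise the reduce_scatter epilogue split on the same graph; whether there are any reduce_scatters to peel off depends on the sharding placement and on which graph (forward, backward, or joint) is inspected:

# Hypothetical companion to the test block above; not part of this commit.
test_split_fsdp_reduce_scatters = True
if test_split_fsdp_reduce_scatters:
    from autoparallel.pipeline.passes import split_fsdp_reduce_scatters_epilogue

    main_g, epi_g = split_fsdp_reduce_scatters_epilogue(autop.parallel_gm.graph)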
