Commit 710e1a6

Author: Sanket Jayant Purandare (committed)
Enabling split_dI_dW and split_fsdp_collectives passes

1 parent 83789ef

File tree: 9 files changed, +858 -229 lines

.github/workflows/test_cuda.yml

Lines changed: 1 addition & 1 deletion

@@ -46,4 +46,4 @@ jobs:
         python examples/example_dcp.py
         python examples/example_local_map.py
         python examples/example_ds3_local_map.py
-        python examples/example_pp_graph_partition.py
+        python examples/example_pp_graph_passes.py

autoparallel/_passes/graph_partition.py

Lines changed: 11 additions & 1 deletion

@@ -29,7 +29,15 @@ def partition_joint_with_descriptors(
     fw_compiler: Callable = boxed_nop_preserve_node_meta,
     bw_compiler: Callable = boxed_nop_preserve_node_meta,
 ) -> tuple[
-    torch.fx.GraphModule, torch.fx.GraphModule, int, int, int, int, list[int], list[Any]
+    torch.fx.GraphModule,
+    torch.fx.GraphModule,
+    int,
+    int,
+    int,
+    int,
+    int,
+    list[int],
+    list[Any],
 ]:
     aot_state: AOTState = jd._aot_state
     aot_graph_capture: AOTGraphCapture = jd._aot_graph_capture
@@ -79,9 +87,11 @@ def partition_joint_with_descriptors(
     num_mutate_inputs = len(
         [x for x in fw_metadata.input_info if x.mutates_data or x.mutates_metadata]
     )
+    num_params_buffers = aot_config.num_params_buffers
     return (
         fw_module,
         bw_module,
+        num_params_buffers,
         num_user_outputs,
         num_mutate_inputs,
         num_fw_outs_saved_for_bw,
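
For orientation, a minimal sketch of how a call site might unpack the widened return value. Only the ordering of the first six entries is taken from the hunk above; `jd` and the `*rest` unpacking are illustrative assumptions, not part of this commit.

    # Hypothetical call site; entries past num_fw_outs_saved_for_bw are not
    # visible in this hunk, so they are left in *rest rather than guessed.
    (
        fw_module,                 # forward torch.fx.GraphModule
        bw_module,                 # backward torch.fx.GraphModule
        num_params_buffers,        # new in this commit: aot_config.num_params_buffers
        num_user_outputs,
        num_mutate_inputs,
        num_fw_outs_saved_for_bw,
        *rest,                     # remaining int, list[int] and list[Any] entries
    ) = partition_joint_with_descriptors(jd)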

autoparallel/_passes/split_di_dw_graph.py

Lines changed: 211 additions & 9 deletions

@@ -4,9 +4,29 @@
 # LICENSE file in the root directory of this source tree.

 import copy
+import itertools
+import operator

+import sympy
+import torch
 import torch.fx as fx
-from functorch.compile import default_partition
+from torch._functorch.partitioners import (
+    SavedForBackwardsAOTOutput,
+    _extract_fwd_bwd_outputs,
+    _extract_graph_with_inputs_outputs,
+    _is_backward_state,
+    _is_bwd_seed_offset,
+    _is_fwd_seed_offset,
+    _is_primal,
+    _remove_by_name,
+    find_symbol_binding_fx_nodes,
+    free_symbols,
+    is_sym_node,
+    is_symbol_binding_fx_node,
+)
+from torch.utils._ordered_set import OrderedSet
+
+from autoparallel.apply_sharding import rename_placeholder_node

 # we are running the default partitioner on the bw graph, which requires AC tags being removed.
 # At this stage we have already finished running AC anyway, since we have a bw graph
@@ -44,21 +64,203 @@ def reorder_output_grads(bw_gm, num_weight_gradients):
     return len(grad_inputs)


-# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
+# This is a copy of the function used by the default partitioner,
+# which does *not* reorder symint activations.
+# This reordering is needed by the custom autograd.Function in AOTDispatcher,
+# but isn't needed in our dI/dW splitting since there is no autograd in the loop.
+# TODO: provide a way to get this behavior automatically out of the default partitioner
+def _extract_fwd_bwd_modules(
+    joint_module: fx.GraphModule,
+    saved_values: list[fx.Node],
+    saved_sym_nodes: list[fx.Node],
+    *,
+    num_fwd_outputs: int,
+) -> tuple[fx.GraphModule, fx.GraphModule]:
+    (
+        fwd_outputs,
+        bwd_outputs,
+        fwd_outputs_descs,
+        bwd_outputs_descs,
+    ) = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
+    placeholders = joint_module.graph.find_nodes(op="placeholder")
+    primal_inputs = [*filter(_is_primal, placeholders)]
+    fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)]
+    bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)]
+    backward_state_inputs = [*filter(_is_backward_state, placeholders)]
+
+    bwd_graph = _extract_graph_with_inputs_outputs(
+        joint_module.graph,
+        saved_values + saved_sym_nodes + bwd_seed_offset_inputs,
+        bwd_outputs,
+        bwd_outputs_descs,
+        "backward",
+        ignore_must_be_in_fw_bw=True,
+    )
+
+    distributed_enabled = torch.distributed.is_available()
+
+    for node in bwd_graph.find_nodes(op="placeholder"):
+        # This is to filter out saved values that don't actually end up being used by the backwards pass
+        if not node.users:
+            _remove_by_name(saved_values, node.name)
+            _remove_by_name(saved_sym_nodes, node.name)
+        # wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw,
+        # but this dead activation is actually a collective,
+        # then the collective will generally be followed by a wait_tensor() call.
+        # we need to peek one node further to see if this wait_tensor is dead as well.
+        elif distributed_enabled and all(
+            n.target is torch.ops._c10d_functional.wait_tensor.default
+            and len(n.users) == 0
+            for n in node.users
+        ):
+            _remove_by_name(saved_values, node.name)
+            _remove_by_name(saved_sym_nodes, node.name)
+        elif _is_backward_state(node):
+            # BackwardState is saved directly
+            _remove_by_name(saved_values, node.name)
+            assert backward_state_inputs
+
+    # Now that we have the finalized list of saved values, we need to ensure
+    # we propagate all symbols which are referenced by backwards inputs.
+    # These are not directly used in the graph but are required for downstream
+    # sizevar assignment
+    saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
+    saved_sym_nodes_binding = []
+    saved_sym_nodes_derived = []
+
+    # Some symbols may already be bound in the directly saved_sym_nodes,
+    # keep track of them so we don't re-bind them
+    for node in saved_sym_nodes:
+        symbol = is_symbol_binding_fx_node(node)
+        if symbol:
+            saved_symbols.add(symbol)
+            saved_sym_nodes_binding.append(node)
+        else:
+            saved_sym_nodes_derived.append(node)
+
+    # Now go through all of the prospective backward inputs and track any
+    # other symbols we need to bind
+    symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph)
+    for node in itertools.chain(saved_sym_nodes_derived, saved_values):
+        if "val" not in node.meta:
+            continue
+        new_symbols = free_symbols(node.meta["val"]) - saved_symbols
+        # NB: Deterministic order please!
+        for s in sorted(new_symbols, key=lambda s: s.name):
+            # NB: For well formed graphs, the symbol should always be present,
+            # but we also have ways to produce ill-formed graphs, e.g., direct
+            # make_fx usages, so don't choke in this case
+            if s not in symbol_bindings:
+                continue
+            saved_sym_nodes_binding.append(symbol_bindings[s])
+        saved_symbols |= new_symbols
+
+    # Update saved_sym_nodes that are now reordered to have all bindings at
+    # front. This can also be used later on to figure out the position of saved
+    # sym nodes in the output of fwd graph.
+    saved_sym_nodes.clear()
+    saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)

+    # Now, we re-generate the fwd/bwd graphs.
+    # NB: This might increase compilation time, but I doubt it matters
+    fwd_graph = _extract_graph_with_inputs_outputs(
+        joint_module.graph,
+        primal_inputs + fwd_seed_offset_inputs,
+        fwd_outputs + saved_values + saved_sym_nodes,
+        fwd_outputs_descs
+        + [
+            SavedForBackwardsAOTOutput(i)
+            for i in range(len(saved_values) + len(saved_sym_nodes))
+        ],
+        "forward",
+        ignore_must_be_in_fw_bw=True,
+    )
+    bwd_graph = _extract_graph_with_inputs_outputs(
+        joint_module.graph,
+        saved_values + saved_sym_nodes + bwd_seed_offset_inputs + backward_state_inputs,
+        bwd_outputs,
+        bwd_outputs_descs,
+        "backward",
+        ignore_must_be_in_fw_bw=True,
+    )
+
+    fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph)
+    bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph)
+    return fwd_module, bwd_module

+
+# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
 def split_di_dw_graph(
-    bw_gm: fx.GraphModule, *, num_weight_gradients
-) -> tuple[fx.GraphModule, fx.GraphModule]:
+    bw_gm_old: fx.GraphModule, *, num_weight_gradients: int
+) -> tuple[fx.GraphModule, fx.GraphModule, int]:
     # we could consider doing this in a non-mutating way
-    bw_gm = copy.deepcopy(bw_gm)
+    bw_gm = copy.deepcopy(bw_gm_old)
+    placeholders = bw_gm.graph.find_nodes(op="placeholder")
+    for p in placeholders:
+        if p.name.startswith("tangent"):
+            name_suffix = p.name[8:]
+            rename_placeholder_node(bw_gm, p, f"not_tngnt{name_suffix}")
+
     remove_recompute_tags(bw_gm)
     num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients)
     bw_gm.recompile()

-    args = [x.meta["val"] for x in bw_gm.graph.find_nodes(op="placeholder")]
+    args = list(bw_gm.graph.find_nodes(op="placeholder"))
+
+    # bw_inputs, bw_weights = default_partition(bw_gm, args, num_fwd_outputs=num_input_gradients)
+    # return bw_inputs, bw_weights, num_input_gradients
+
+    (
+        grad_inps,
+        grad_weights,
+        grad_inp_descs,
+        grad_weight_descs,
+    ) = _extract_fwd_bwd_outputs(bw_gm, num_fwd_outputs=num_input_gradients)
+    bw_inputs_gm = _extract_graph_with_inputs_outputs(
+        bw_gm.graph,
+        args,
+        grad_inps,
+        grad_inp_descs,
+        "forward",
+        ignore_must_be_in_fw_bw=True,
+    )
+    bw_inputs_gm_node_names = OrderedSet(
+        node.name for node in bw_inputs_gm.nodes if node.op != "output"
+    )
+    saved_values = []
+    saved_sym_nodes = []

-    bw_inputs, bw_weights = default_partition(
-        bw_gm, args, num_fwd_outputs=num_input_gradients
+    for node in bw_gm.graph.nodes:
+        if node.name not in bw_inputs_gm_node_names:
+            # Not handling mutations for now,
+            # we can try to re-use more of and/or consolidate with default partitioner
+            continue
+        if is_sym_node(node):
+            saved_sym_nodes.append(node)
+        elif (
+            "tensor_meta" not in node.meta
+            and node.op == "call_function"
+            and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
+        ):
+            users = node.users
+            assert all(user.target == operator.getitem for user in users)
+            saved_values.extend(users)
+        else:
+            backward_usages = [
+                n for n in node.users if n.name not in bw_inputs_gm_node_names
+            ]
+            if "tensor_meta" in node.meta and all(
+                is_sym_node(n) for n in backward_usages
+            ):
+                saved_sym_nodes.extend(backward_usages)
+            else:
+                saved_values.append(node)
+    saved_values = list(dict.fromkeys(saved_values).keys())
+    saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
+    bw_inputs, bw_weights = _extract_fwd_bwd_modules(
+        bw_gm,
+        saved_values,
+        saved_sym_nodes=saved_sym_nodes,
+        num_fwd_outputs=num_input_gradients,
     )
-    return bw_inputs, bw_weights
+    return bw_inputs, bw_weights, num_input_gradients
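
A hedged sketch of how the updated pass might be driven end to end; the `bw_module` and `num_weight_gradients` inputs are assumed to come from the partitioning step above, and only the function's signature and three-element return come from this diff.

    import torch.fx as fx

    from autoparallel._passes.split_di_dw_graph import split_di_dw_graph

    def split_backward(bw_module: fx.GraphModule, num_weight_gradients: int):
        # The pass now returns three values: the dI graph (input gradients),
        # the dW graph (weight gradients), and the number of input gradients,
        # which tells the caller where the dI outputs stop in the output list.
        bw_dI, bw_dW, num_input_gradients = split_di_dw_graph(
            bw_module, num_weight_gradients=num_weight_gradients
        )
        return bw_dI, bw_dW, num_input_gradients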
