Commit 9aebf3b

PP Runner Full Prototype (#221)
1 parent 0510df2 commit 9aebf3b

File tree

13 files changed: +1423 -87 lines changed


.github/workflows/test_cuda.yml

Lines changed: 1 addition & 0 deletions
@@ -46,3 +46,4 @@ jobs:
    python examples/example_dcp.py
    python examples/example_local_map.py
    python examples/example_ds3_local_map.py
+   python examples/example_pp_graph_partition.py

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -10,5 +10,6 @@

 build/
 dist/
+tmp/

 .vscode/

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ repos:
 - repo: local
   hooks:
   - id: mypy
+    require_serial: true
     name: mypy
     entry: mypy
     language: system

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
import torch.fx as fx


def multiplex_fw_bw_graph(
    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule
) -> fx.GraphModule:
    """
    Multiplexes forward and backward graphs into a single unified graph module.

    This function combines a forward graph and a backward graph into one multiplexed
    graph by merging their nodes and outputs. The resulting graph has:
    - All placeholders from both forward and backward graphs (backward followed by forward)
    - All computation nodes from both graphs (backward followed by forward)
    - Combined outputs (backward outputs followed by forward outputs)

    Args:
        fw_gm: The forward graph module containing the forward computation
        bw_gm: The backward graph module containing the backward computation

    Returns:
        A multiplexed fx.GraphModule containing both forward and backward computations,
        with backward outputs appearing before forward outputs

    Note:
        The function preserves node metadata during the merging process.
    """
    # Mapping to track correspondence between backward graph nodes and new nodes
    old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {}

    # Start with a deep copy of the forward graph as the base
    multiplexed_gm = copy.deepcopy(fw_gm)

    # Collect all placeholder nodes from the backward graph
    bw_placeholders = []
    for n in bw_gm.graph.nodes:
        if n.op == "placeholder":
            bw_placeholders.append(n)

    # Insert backward placeholders at the beginning of the multiplexed graph.
    # Reversed order ensures correct execution sequence.
    with multiplexed_gm.graph.inserting_before():
        for n in reversed(bw_placeholders):
            new_placeholder = multiplexed_gm.graph.placeholder(n.name)
            new_placeholder.meta = n.meta
            new_placeholder.target = new_placeholder.name
            old_node_to_new_node[n] = new_placeholder

    # Find the last placeholder and the output node in the multiplexed graph
    insert_point = None
    multiplexed_graph_op_node = None
    for n in multiplexed_gm.graph.nodes:
        if n.op == "placeholder":
            insert_point = n
        if n.op == "output":
            multiplexed_graph_op_node = n

    # Copy all computation nodes from backward graph into multiplexed graph
    bw_graph_op_node = None
    for n in bw_gm.graph.nodes:
        if n.op == "placeholder":
            continue
        if n.op == "output":
            bw_graph_op_node = n
            continue
        with multiplexed_gm.graph.inserting_after(insert_point):
            # Copy node and remap its arguments using the node mapping
            new_node = multiplexed_gm.graph.node_copy(
                n, lambda x: old_node_to_new_node[x]
            )
            new_node.meta = n.meta
            old_node_to_new_node[n] = new_node
        insert_point = new_node

    assert bw_graph_op_node is not None
    assert multiplexed_graph_op_node is not None

    # Collect output arguments from backward graph, remapping to new nodes
    bw_op_node_args = [
        old_node_to_new_node[n] if n is not None else None
        for n in bw_graph_op_node.args[0]
    ]

    # Collect output arguments from forward graph
    fw_op_node_args = list(multiplexed_graph_op_node.args[0])

    # Remove the old output node and create new combined output
    insert_point = multiplexed_graph_op_node.prev
    multiplexed_gm.graph.erase_node(multiplexed_graph_op_node)

    # Create combined output with backward outputs first, then forward outputs
    with multiplexed_gm.graph.inserting_after(insert_point):
        multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args)

    multiplexed_gm.graph.eliminate_dead_code()
    multiplexed_gm.graph.lint()
    multiplexed_gm.recompile()
    return multiplexed_gm
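
For illustration, here is a minimal sketch (not part of this commit) of calling multiplex_fw_bw_graph on two hand-traced toy graphs. It assumes the function above is importable in the current scope and uses stand-in forward/backward functions rather than real AOTAutograd graphs:

import torch.fx as fx


def toy_fw(x, w):
    # single forward output, returned as a tuple so output.args[0] is iterable
    return (x * w,)


def toy_bw(grad_out, x_saved, w_saved):
    # toy backward producing (grad_x, grad_w)
    return (grad_out * w_saved, grad_out * x_saved)


fw_gm = fx.symbolic_trace(toy_fw)
bw_gm = fx.symbolic_trace(toy_bw)

multiplexed = multiplex_fw_bw_graph(fw_gm, bw_gm)
# Placeholders: backward inputs first, then forward inputs.
# Outputs: backward gradients first, then the forward output.
print(multiplexed.graph)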

autoparallel/_passes/graph_partition.py

Lines changed: 0 additions & 3 deletions
@@ -79,9 +79,6 @@ def partition_joint_with_descriptors(
    num_mutate_inputs = len(
        [x for x in fw_metadata.input_info if x.mutates_data or x.mutates_metadata]
    )
-   print(fw_module.graph)
-   print(fw_module.graph)
-
    return (
        fw_module,
        bw_module,

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch.fx as fx
from functorch.compile import default_partition

# We are running the default partitioner on the bw graph, which requires AC tags to be removed.
# At this stage we have already finished running AC anyway, since we have a bw graph.


def remove_recompute_tags(bw_gm):
    for n in bw_gm.graph.nodes:
        if "recompute" in n.meta:
            del n.meta["recompute"]


# We are using the default partitioner to split our backward into dI and dW subgraphs.
# We want to generate the dI subgraph *first*, because:
# - in pipelining we generally want to schedule dI compute before dW
# - the dI compute will potentially compute more activations that we need to plumb into dW compute
# Today, the default partitioner requires that you split on the first K outputs of your combined graph.
# So here, we reorder the outputs of the backward so grad_inputs are first.


def reorder_output_grads(bw_gm, num_weight_gradients):
    outputs = bw_gm.graph.find_nodes(op="output")
    assert len(outputs) == 1
    output = outputs[0]
    assert isinstance(output.args[0], tuple)
    grad_weights, grad_inputs = (
        output.args[0][:num_weight_gradients],
        output.args[0][num_weight_gradients:],
    )
    new_out_tuple = grad_inputs + grad_weights
    with bw_gm.graph.inserting_after(output):
        # TODO: also set the new node's meta properly
        new_out = bw_gm.graph.output(new_out_tuple)
    output.replace_all_uses_with(new_out)
    bw_gm.graph.erase_node(output)
    return len(grad_inputs)


# TODO: in theory we can infer num_weight_gradients from the graph metadata directly


def split_di_dw_graph(
    bw_gm: fx.GraphModule, *, num_weight_gradients
) -> tuple[fx.GraphModule, fx.GraphModule]:
    # we could consider doing this in a non-mutating way
    bw_gm = copy.deepcopy(bw_gm)
    remove_recompute_tags(bw_gm)
    num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients)
    bw_gm.recompile()

    args = [x.meta["val"] for x in bw_gm.graph.find_nodes(op="placeholder")]

    bw_inputs, bw_weights = default_partition(
        bw_gm, args, num_fwd_outputs=num_input_gradients
    )
    return bw_inputs, bw_weights
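
As a small illustration of the output-reordering step (not part of this commit): the sketch below builds a toy "backward" graph with symbolic_trace and calls reorder_output_grads so that input gradients come before weight gradients. A symbolically traced toy graph has no meta["val"] on its placeholders, so this exercises only the reordering helper, not the full split_di_dw_graph path; it assumes the helpers above are importable in the current scope.

import torch.fx as fx


def toy_bw(w1, w2, x):
    # pretend output ordering is (grad_w1, grad_w2, grad_x)
    return (w1 * 2, w2 * 3, x * 4)


bw_gm = fx.symbolic_trace(toy_bw)
num_input_grads = reorder_output_grads(bw_gm, num_weight_gradients=2)
bw_gm.recompile()
assert num_input_grads == 1
# The recompiled module now returns (grad_x, grad_w1, grad_w2).
print(bw_gm.graph)
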
Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
from contextlib import contextmanager
from functools import partial
from typing import Any

import torch
import torch.fx.node
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
from torch._inductor.fx_passes.bucketing import (
    is_all_gather_into_tensor,
    is_reduce_scatter_tensor,
)


@contextmanager
def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
    original_val = torch.fx.node._side_effectful_functions.copy()
    try:
        torch.fx.node._side_effectful_functions -= exclude_vals
        yield
    finally:
        torch.fx.node._side_effectful_functions.clear()
        torch.fx.node._side_effectful_functions.update(original_val)


exclude_wait_from_fx_side_effectful = partial(
    exclude_from_fx_side_effectful,
    {
        torch.ops._c10d_functional.wait_tensor,
        torch.ops._c10d_functional.wait_tensor.default,
    },
)


@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
    pass


@dataclasses.dataclass(frozen=True)
class EpilogueInput(AOTOutput):
    pass


def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    g_ins = g.find_nodes(op="placeholder")
    prefetch_g_outs_map = []

    for g_in in g_ins:
        n = g_in
        last_ag = None
        while True:
            if len(n.users) != 1:
                break
            user = next(iter(n.users))
            if len(user.all_input_nodes) > 1:
                break
            n = user
            if is_all_gather_into_tensor(n):
                last_ag = n
        if last_ag is None:
            prefetch_g_outs_map.append(g_in)
        else:
            w_n = next(iter(last_ag.users))
            prefetch_g_outs_map.append(w_n)

    prefetch_g_outs = prefetch_g_outs_map
    prefetch_g_outs_descs: list[AOTOutput] = [
        PrefetchOutput() for _ in range(len(prefetch_g_outs))
    ]
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )
    with exclude_wait_from_fx_side_effectful():
        prefetch_g = _extract_graph_with_inputs_outputs(
            g,
            g_ins,
            prefetch_g_outs,
            prefetch_g_outs_descs,
            ignore_must_be_in_fw_bw=True,
        )

        main_g = _extract_graph_with_inputs_outputs(
            g,
            prefetch_g_outs,
            g_outs,
            g_outs_descs,
            ignore_must_be_in_fw_bw=True,
        )
    return prefetch_g, main_g


def split_fsdp_reduce_scatters_epilogue(
    g: torch.fx.Graph,
) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    g_ins = g.find_nodes(op="placeholder")
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )

    g_outs_map = []
    for g_out in g_outs:
        n = g_out
        last_rs = None
        while n is not None:
            if len(n.all_input_nodes) != 1:
                break
            n_in = n.all_input_nodes[0]
            if len(n_in.users) > 1:
                break
            prev_n = n
            n = n_in
            if is_reduce_scatter_tensor(prev_n):
                # In AP, for mesh dim > 1, the reduction of gradients
                # happens in multiple steps
                last_rs = n
        if last_rs is not None:
            g_outs_map.append(last_rs)
        else:
            g_outs_map.append(g_out)

    epi_g_ins = [n for n in g_outs_map if n is not None]
    epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]

    with exclude_wait_from_fx_side_effectful():
        main_g = _extract_graph_with_inputs_outputs(
            g,
            g_ins,
            epi_g_ins,
            epi_g_ins_descs,
            ignore_must_be_in_fw_bw=True,
        )
        epi_g = _extract_graph_with_inputs_outputs(
            g,
            epi_g_ins,
            g_outs,
            g_outs_descs,
            ignore_must_be_in_fw_bw=True,
        )

    return main_g, epi_g
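
The exclude_from_fx_side_effectful context manager above temporarily removes the functional-collective wait_tensor ops from torch.fx's registry of side-effectful functions, so that graph extraction and dead-code elimination are free to prune or move them, and then restores the registry on exit. Below is a minimal sketch of the mechanics (not part of this commit, assuming a torch build with distributed support and the helpers above in scope):

import torch
import torch.fx

# Build a graph containing an unused wait_tensor call.
g = torch.fx.Graph()
x = g.placeholder("x")
g.call_function(torch.ops._c10d_functional.wait_tensor.default, (x,))
g.output(x)

registry_before = torch.fx.node._side_effectful_functions.copy()
with exclude_wait_from_fx_side_effectful():
    # Inside the context wait_tensor is not treated as side-effectful,
    # so dead-code elimination may drop the unused call.
    g.eliminate_dead_code()
print(g)
# The registry is restored on exit.
assert torch.fx.node._side_effectful_functions == registry_before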
