
Commit 2f1452e

Sanket Jayant Purandare authored and xmfan committed
Graph Multiplex Pass Example
1 parent cb3059e commit 2f1452e

File tree

3 files changed: +114 -2 lines changed
autoparallel/_passes/graph_multiplex.py

Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
import copy

import torch
import torch.fx as fx


def multiplex_fw_bw_graph(
    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule
) -> fx.GraphModule:
    """
    Multiplexes forward and backward graphs into a single unified graph module.

    This function combines a forward graph and a backward graph into one multiplexed
    graph by merging their nodes and outputs. The resulting graph has:
    - All placeholders from both forward and backward graphs (backward followed by forward)
    - All computation nodes from both graphs (backward followed by forward)
    - Combined outputs (backward outputs followed by forward outputs)

    Args:
        fw_gm: The forward graph module containing the forward computation
        bw_gm: The backward graph module containing the backward computation

    Returns:
        A multiplexed fx.GraphModule containing both forward and backward computations
        with backward outputs appearing before forward outputs

    Note:
        The function preserves node metadata during the merging process.
    """
    # Mapping to track correspondence between backward graph nodes and new nodes
    old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {}

    # Start with a deep copy of the forward graph as the base
    multiplexed_gm = copy.deepcopy(fw_gm)

    # Collect all placeholder nodes from the backward graph
    bw_placeholders = []
    for n in bw_gm.graph.nodes:
        if n.op == "placeholder":
            bw_placeholders.append(n)

    # Insert backward placeholders at the beginning of the multiplexed graph
    # Reversed order ensures correct execution sequence
    with multiplexed_gm.graph.inserting_before():
        for n in reversed(bw_placeholders):
            new_placeholder = multiplexed_gm.graph.placeholder(n.name)
            new_placeholder.meta = n.meta
            new_placeholder.target = new_placeholder.name
            old_node_to_new_node[n] = new_placeholder

    # Find the last placeholder and the output node in the multiplexed graph
    insert_point = None
    multiplexed_graph_op_node = None
    for n in multiplexed_gm.graph.nodes:
        if n.op == "placeholder":
            insert_point = n
        if n.op == "output":
            multiplexed_graph_op_node = n

    # Copy all computation nodes from backward graph into multiplexed graph
    bw_graph_op_node = None
    for n in bw_gm.graph.nodes:
        if n.op == "placeholder":
            continue
        if n.op == "output":
            bw_graph_op_node = n
            continue
        with multiplexed_gm.graph.inserting_after(insert_point):
            # Copy node and remap its arguments using the node mapping
            new_node = multiplexed_gm.graph.node_copy(
                n, lambda x: old_node_to_new_node[x]
            )
            new_node.meta = n.meta
            old_node_to_new_node[n] = new_node
        insert_point = new_node

    assert bw_graph_op_node is not None
    assert multiplexed_graph_op_node is not None

    # Collect output arguments from backward graph, remapping to new nodes
    bw_op_node_args = [
        old_node_to_new_node[n] if n is not None else None
        for n in bw_graph_op_node.args[0]
    ]

    # Collect output arguments from forward graph
    fw_op_node_args = list(multiplexed_graph_op_node.args[0])

    # Remove the old output node and create new combined output
    insert_point = multiplexed_graph_op_node.prev
    multiplexed_gm.graph.erase_node(multiplexed_graph_op_node)

    # Create combined output with backward outputs first, then forward outputs
    with multiplexed_gm.graph.inserting_after(insert_point):
        multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args)

    multiplexed_gm.graph.eliminate_dead_code()
    multiplexed_gm.graph.lint()
    multiplexed_gm.recompile()
    return multiplexed_gm
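
To illustrate what the pass produces, here is a minimal usage sketch (not part of this commit) that runs multiplex_fw_bw_graph on two toy torch.fx graphs. The helpers toy_fw and toy_bw are hypothetical stand-ins for the partitioned forward/backward modules; they return tuples so their output nodes carry a sequence of results, as the real partitioned graphs do.

# Hypothetical illustration only; toy_fw / toy_bw stand in for real fw/bw modules.
import torch
import torch.fx as fx

from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph


def toy_fw(x):
    return (x * 2, x + 1)


def toy_bw(grad_out):
    return (grad_out * 3,)


fw_gm = fx.symbolic_trace(toy_fw)
bw_gm = fx.symbolic_trace(toy_bw)

multiplexed_gm = multiplex_fw_bw_graph(fw_gm, bw_gm)
print(multiplexed_gm.graph)

# Placeholders are ordered backward-then-forward, so the call signature is
# (grad_out, x); outputs are (backward outputs..., forward outputs...).
outs = multiplexed_gm(torch.ones(4), torch.arange(4.0))
assert len(outs) == 3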

autoparallel/api.py

Lines changed: 4 additions & 1 deletion
@@ -517,11 +517,14 @@ def apply_placement(self, sharding_placement=None):
             sharding_placement
         )

-        self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors(
+        parallel_model_fn, fw_module, bw_module = aot_compile_joint_with_descriptors(
             self.joint_with_descriptors,
             fw_compiler=self.compiler_fn,
             bw_compiler=self.compiler_fn,
         )
+        self.parallel_model_fn = parallel_model_fn
+        self.fw_module = fw_module
+        self.bw_module = bw_module

         # TODO: this probably belongs in the AOTAutograd API
         # TODO: pytree handling

examples/example_llama3.py

Lines changed: 10 additions & 1 deletion
@@ -11,6 +11,7 @@
 from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
 from torch.testing._internal.distributed.fake_pg import FakeStore

+from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
 from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
 from autoparallel.api import AutoParallel
 from autoparallel.auto_bucketing import (
@@ -57,7 +58,7 @@ def model_fn():
     if model_type == "8b":
         model_args = TransformerModelArgs(
             dim=4096,
-            n_layers=32,
+            n_layers=1,
             n_heads=32,
             n_kv_heads=8,
             ffn_dim_multiplier=1.3,
@@ -252,6 +253,14 @@ def _pass(graph):
     sharding_placement = autop.optimize_placement(verbose=True)
     print(f"Took {time.time() - t:.2f} s")
     parallel_mod = autop.apply_placement(sharding_placement)
+    multiplex_graph = True
+    if multiplex_graph:
+        f_gm = autop.fw_module
+        b_gm = autop.bw_module
+        multiplexed_gm = multiplex_fw_bw_graph(f_gm, b_gm)
+        print(f_gm.graph)
+        print(b_gm.graph)
+        print(multiplexed_gm.graph)

 # run weight init on our sharded DTensor params
 parallel_mod.to_empty(device="cuda")
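
A possible sanity check (also not part of this commit) on the documented ordering contract: the multiplexed graph's outputs are the backward outputs followed by the forward outputs, so its output count should equal the sum of the two. The sketch below reuses the f_gm, b_gm, and multiplexed_gm names from the example above; the _output_args helper is hypothetical.

# Hypothetical check: combined output arity equals bw outputs + fw outputs.
def _output_args(gm):
    # The output node is always the last node of an fx graph.
    return list(gm.graph.nodes)[-1].args[0]

assert len(_output_args(multiplexed_gm)) == len(_output_args(b_gm)) + len(
    _output_args(f_gm)
)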
