Merged

Changes from all commits (36 commits):
2f1452e  Graph Multiplex Pass Example (Oct 16, 2025)
b4a76e9  Adding split FSDP Collective Pass (Oct 16, 2025)
8175160  Add split dI/dW graph pass example (bdhirsh, Oct 29, 2025)
e314df0  Pass to split all_gather prologue and reduce_scatter prologue from fs… (IvanKobzarev, Oct 10, 2025)
5c6dc37  add asserts for arg count before running boxed_run (xmfan, Oct 29, 2025)
da5b9f3  first split dsv3 into 8 stages then autoparallel the first stage (xmfan, Oct 29, 2025)
cda305f  tlparse fwd pp and bwd pp (xmfan, Oct 29, 2025)
51cb497  fix stage 1-7 (xmfan, Oct 29, 2025)
cc61714  unpack fw outs of len 1 (xmfan, Oct 30, 2025)
361d23f  Temporarily Revert "Graph Multiplex Pass Example" (xmfan, Oct 30, 2025)
37e8fce  FSDP prefetch/reduce scatter extraction passes don't work (xmfan, Oct 30, 2025)
a84d5d1  di/dw passes don't work either (xmfan, Oct 30, 2025)
17f8ac1  First draft PP runner (Oct 30, 2025)
e22818e  trace and run all 8 stages one after the other (xmfan, Oct 30, 2025)
6885795  Add comments (xmfan, Oct 30, 2025)
ee737af  PP Runner Complete (Oct 30, 2025)
34cb053  lintfix (ezyang, Oct 30, 2025)
97566c6  Merge remote-tracking branch 'origin/main' into war-oct29 (ezyang, Oct 30, 2025)
226262f  Black format (ezyang, Oct 30, 2025)
21d00b3  tlparse markers (xmfan, Oct 30, 2025)
75a30d1  Try caching (fmassa, Oct 30, 2025)
6bc5064  Merge remote-tracking branch 'origin/fmassa/war_oct29' into war-oct29 (ezyang, Oct 30, 2025)
003024f  Disable cache (ezyang, Oct 30, 2025)
ac90d51  support more arbitrary layers per stage (xmfan, Oct 30, 2025)
1a600bd  Make caching work (fmassa, Oct 31, 2025)
182a330  [ag/rs split passes] Clear partitioner_tag; take more than one ag/rs … (IvanKobzarev, Oct 31, 2025)
af77b95  fix test_graph_partition.py (xmfan, Oct 31, 2025)
e09b6d3  lint + move pp to bottom of api.py (xmfan, Oct 31, 2025)
a1e9768  remove unused examples/test scripts (xmfan, Oct 31, 2025)
50bd920  move pipeline stages to dsv3.py (xmfan, Oct 31, 2025)
6131e20  Refactor war branch: Improved caching, PP Runner now only deals with … (Nov 1, 2025)
fab88c5  mypy requires serial for thread safety (xmfan, Nov 3, 2025)
cbc61e4  move graph partition to example and run it in ci (xmfan, Nov 3, 2025)
4acd7fd  Use pytorch partitioner, remove pipelines/utils.py fork (#223) (IvanKobzarev, Nov 3, 2025)
d1ac408  evaluate using fake tensors for ci (xmfan, Nov 3, 2025)
4b4fa08  fakes for example_pp_graph_partition.py (xmfan, Nov 3, 2025)
1 change: 1 addition & 0 deletions .github/workflows/test_cuda.yml
@@ -46,3 +46,4 @@ jobs:
python examples/example_dcp.py
python examples/example_local_map.py
python examples/example_ds3_local_map.py
python examples/example_pp_graph_partition.py
1 change: 1 addition & 0 deletions .gitignore
@@ -10,5 +10,6 @@

build/
dist/
tmp/

.vscode/
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -37,6 +37,7 @@ repos:
  - repo: local
    hooks:
      - id: mypy
        require_serial: true
        name: mypy
        entry: mypy
        language: system
105 changes: 105 additions & 0 deletions autoparallel/_passes/graph_multiplex.py
@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
import torch.fx as fx


def multiplex_fw_bw_graph(
    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule
) -> fx.GraphModule:
    """
    Multiplexes forward and backward graphs into a single unified graph module.

    This function combines a forward graph and a backward graph into one multiplexed
    graph by merging their nodes and outputs. The resulting graph has:
    - All placeholders from both forward and backward graphs (backward followed by forward)
    - All computation nodes from both graphs (backward followed by forward)
    - Combined outputs (backward outputs followed by forward outputs)

    Args:
        fw_gm: The forward graph module containing the forward computation
        bw_gm: The backward graph module containing the backward computation

    Returns:
        A multiplexed fx.GraphModule containing both forward and backward computations
        with backward outputs appearing before forward outputs

    Note:
        The function preserves node metadata during the merging process.
    """
    # Mapping to track correspondence between backward graph nodes and new nodes
    old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {}

    # Start with a deep copy of the forward graph as the base
    multiplexed_gm = copy.deepcopy(fw_gm)

    # Collect all placeholder nodes from the backward graph
    bw_placeholders = []
    for n in bw_gm.graph.nodes:
        if n.op == "placeholder":
            bw_placeholders.append(n)

    # Insert backward placeholders at the beginning of the multiplexed graph.
    # Each new placeholder is inserted at the front, so iterating in reverse
    # preserves the original backward placeholder order.
    with multiplexed_gm.graph.inserting_before():
        for n in reversed(bw_placeholders):
            new_placeholder = multiplexed_gm.graph.placeholder(n.name)
            new_placeholder.meta = n.meta
            new_placeholder.target = new_placeholder.name
            old_node_to_new_node[n] = new_placeholder

    # Find the last placeholder and the output node in the multiplexed graph
    insert_point = None
    multiplexed_graph_op_node = None
    for n in multiplexed_gm.graph.nodes:
        if n.op == "placeholder":
            insert_point = n
        if n.op == "output":
            multiplexed_graph_op_node = n

    # Copy all computation nodes from backward graph into multiplexed graph
    bw_graph_op_node = None
    for n in bw_gm.graph.nodes:
        if n.op == "placeholder":
            continue
        if n.op == "output":
            bw_graph_op_node = n
            continue
        with multiplexed_gm.graph.inserting_after(insert_point):
            # Copy node and remap its arguments using the node mapping
            new_node = multiplexed_gm.graph.node_copy(
                n, lambda x: old_node_to_new_node[x]
            )
            new_node.meta = n.meta
            old_node_to_new_node[n] = new_node
            insert_point = new_node

    assert bw_graph_op_node is not None
    assert multiplexed_graph_op_node is not None

    # Collect output arguments from backward graph, remapping to new nodes
    bw_op_node_args = [
        old_node_to_new_node[n] if n is not None else None
        for n in bw_graph_op_node.args[0]
    ]

    # Collect output arguments from forward graph
    fw_op_node_args = list(multiplexed_graph_op_node.args[0])

    # Remove the old output node and create new combined output
    insert_point = multiplexed_graph_op_node.prev
    multiplexed_gm.graph.erase_node(multiplexed_graph_op_node)

    # Create combined output with backward outputs first, then forward outputs
    with multiplexed_gm.graph.inserting_after(insert_point):
        multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args)

    multiplexed_gm.graph.eliminate_dead_code()
    multiplexed_gm.graph.lint()
    multiplexed_gm.recompile()
    return multiplexed_gm
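
For illustration, here is a minimal usage sketch (not part of this PR's diff) exercising multiplex_fw_bw_graph on two toy traced modules. The functions toy_fw and toy_bw are hypothetical stand-ins for the partitioned forward/backward graph modules; per the docstring above, the multiplexed module takes the backward inputs first, then the forward inputs, and returns the backward outputs followed by the forward outputs.

import torch
import torch.fx as fx

from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph


def toy_fw(x):
    # stand-in for a partitioned forward module; returns a tuple of outputs
    return (torch.relu(x),)


def toy_bw(grad_out):
    # stand-in for a partitioned backward module
    return (grad_out * 2,)


fw_gm = fx.symbolic_trace(toy_fw)
bw_gm = fx.symbolic_trace(toy_bw)

multi_gm = multiplex_fw_bw_graph(fw_gm, bw_gm)

# Backward placeholders come first, then forward placeholders; the result is
# the backward outputs followed by the forward outputs.
bw_out, fw_out = multi_gm(torch.ones(4), torch.randn(4))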
3 changes: 0 additions & 3 deletions autoparallel/_passes/graph_partition.py
@@ -79,9 +79,6 @@ def partition_joint_with_descriptors(
    num_mutate_inputs = len(
        [x for x in fw_metadata.input_info if x.mutates_data or x.mutates_metadata]
    )
    print(fw_module.graph)
    print(fw_module.graph)

    return (
        fw_module,
        bw_module,
64 changes: 64 additions & 0 deletions autoparallel/_passes/split_di_dw_graph.py
@@ -0,0 +1,64 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch.fx as fx
from functorch.compile import default_partition

# We are running the default partitioner on the bw graph, which requires the AC
# (activation checkpointing) recompute tags to be removed. At this stage AC has
# already finished running anyway, since we already have a bw graph.


def remove_recompute_tags(bw_gm):
    for n in bw_gm.graph.nodes:
        if "recompute" in n.meta:
            del n.meta["recompute"]


# We are using the default partitioner to split our backward into dI and dW subgraphs.
# We want to generate the dI subgraph *first*, because:
# - in pipelining we generally want to schedule dI compute before dW
# - the dI compute may produce additional activations that we need to plumb into the dW compute
# Today, the default partitioner requires that you split on the first K outputs of your
# combined graph. So here, we reorder the outputs of the backward so grad_inputs come first.


def reorder_output_grads(bw_gm, num_weight_gradients):
    outputs = bw_gm.graph.find_nodes(op="output")
    assert len(outputs) == 1
    output = outputs[0]
    assert isinstance(output.args[0], tuple)
    grad_weights, grad_inputs = (
        output.args[0][:num_weight_gradients],
        output.args[0][num_weight_gradients:],
    )
    new_out_tuple = grad_inputs + grad_weights
    with bw_gm.graph.inserting_after(output):
        # TODO: also set the new node's meta properly
        new_out = bw_gm.graph.output(new_out_tuple)
    output.replace_all_uses_with(new_out)
    bw_gm.graph.erase_node(output)
    return len(grad_inputs)


# TODO: in theory we can infer num_weight_gradients from the graph metadata directly


def split_di_dw_graph(
    bw_gm: fx.GraphModule, *, num_weight_gradients
) -> tuple[fx.GraphModule, fx.GraphModule]:
    # we could consider doing this in a non-mutating way
    bw_gm = copy.deepcopy(bw_gm)
    remove_recompute_tags(bw_gm)
    num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients)
    bw_gm.recompile()

    args = [x.meta["val"] for x in bw_gm.graph.find_nodes(op="placeholder")]

    bw_inputs, bw_weights = default_partition(
        bw_gm, args, num_fwd_outputs=num_input_gradients
    )
    return bw_inputs, bw_weights
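
To make the output-reordering step concrete, here is a small sketch in which a toy traced graph stands in for the real backward module (a hypothetical example, not part of the diff). The full split_di_dw_graph additionally relies on placeholder meta["val"] fake tensors populated by the AOT pipeline before default_partition can run; the toy graph below only exercises reorder_output_grads.

import torch
import torch.fx as fx

from autoparallel._passes.split_di_dw_graph import reorder_output_grads


def toy_bw(w1, w2, x):
    # stand-in backward: two weight gradients followed by one input gradient
    return (w1 * 2, w2 * 3, x * 4)


bw_gm = fx.symbolic_trace(toy_bw)
num_input_grads = reorder_output_grads(bw_gm, num_weight_gradients=2)
bw_gm.recompile()

# The input gradient is now the first output, matching the default
# partitioner's "split on the first K outputs" requirement.
assert num_input_grads == 1
grad_x, grad_w1, grad_w2 = bw_gm(torch.ones(2), torch.ones(2), torch.ones(2))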
150 changes: 150 additions & 0 deletions autoparallel/_passes/split_fsdp_collectives.py
@@ -0,0 +1,150 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
from contextlib import contextmanager
from functools import partial
from typing import Any

import torch
import torch.fx.node
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
from torch._inductor.fx_passes.bucketing import (
    is_all_gather_into_tensor,
    is_reduce_scatter_tensor,
)


@contextmanager
def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
    original_val = torch.fx.node._side_effectful_functions.copy()
    try:
        torch.fx.node._side_effectful_functions -= exclude_vals
        yield
    finally:
        torch.fx.node._side_effectful_functions.clear()
        torch.fx.node._side_effectful_functions.update(original_val)


exclude_wait_from_fx_side_effectful = partial(
    exclude_from_fx_side_effectful,
    {
        torch.ops._c10d_functional.wait_tensor,
        torch.ops._c10d_functional.wait_tensor.default,
    },
)


@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
    pass


@dataclasses.dataclass(frozen=True)
class EpilogueInput(AOTOutput):
    pass


def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    g_ins = g.find_nodes(op="placeholder")
    prefetch_g_outs_map = []

    for g_in in g_ins:
        n = g_in
        last_ag = None
        while True:
            if len(n.users) != 1:
                break
            user = next(iter(n.users))
            if len(user.all_input_nodes) > 1:
                break
            n = user
            if is_all_gather_into_tensor(n):
                last_ag = n
        if last_ag is None:
            prefetch_g_outs_map.append(g_in)
        else:
            w_n = next(iter(last_ag.users))
            prefetch_g_outs_map.append(w_n)

    prefetch_g_outs = prefetch_g_outs_map
    prefetch_g_outs_descs: list[AOTOutput] = [
        PrefetchOutput() for _ in range(len(prefetch_g_outs))
    ]
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )
    with exclude_wait_from_fx_side_effectful():
        prefetch_g = _extract_graph_with_inputs_outputs(
            g,
            g_ins,
            prefetch_g_outs,
            prefetch_g_outs_descs,
            ignore_must_be_in_fw_bw=True,
        )

        main_g = _extract_graph_with_inputs_outputs(
            g,
            prefetch_g_outs,
            g_outs,
            g_outs_descs,
            ignore_must_be_in_fw_bw=True,
        )
    return prefetch_g, main_g


def split_fsdp_reduce_scatters_epilogue(
    g: torch.fx.Graph,
) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    g_ins = g.find_nodes(op="placeholder")
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )

    g_outs_map = []
    for g_out in g_outs:
        n = g_out
        last_rs = None
        while n is not None:
            if len(n.all_input_nodes) != 1:
                break
            n_in = n.all_input_nodes[0]
            if len(n_in.users) > 1:
                break
            prev_n = n
            n = n_in
            if is_reduce_scatter_tensor(prev_n):
                # In AP, for mesh dim > 1 the reduction of gradients
                # happens in multiple steps
                last_rs = n
        if last_rs is not None:
            g_outs_map.append(last_rs)
        else:
            g_outs_map.append(g_out)

    epi_g_ins = [n for n in g_outs_map if n is not None]
    epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]

    with exclude_wait_from_fx_side_effectful():
        main_g = _extract_graph_with_inputs_outputs(
            g,
            g_ins,
            epi_g_ins,
            epi_g_ins_descs,
            ignore_must_be_in_fw_bw=True,
        )
        epi_g = _extract_graph_with_inputs_outputs(
            g,
            epi_g_ins,
            g_outs,
            g_outs_descs,
            ignore_must_be_in_fw_bw=True,
        )

    return main_g, epi_g
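
One piece worth calling out: exclude_wait_from_fx_side_effectful exists because FX will not dead-code-eliminate calls registered as side-effectful, which could otherwise pin wait_tensor nodes into a subgraph that no longer contains their collective. A minimal sketch of the effect (an illustrative example, assuming a PyTorch build recent enough to ship these collective ops and partitioner helpers):

import torch
import torch.fx.node

from autoparallel._passes.split_fsdp_collectives import (
    exclude_wait_from_fx_side_effectful,
)

wait_op = torch.ops._c10d_functional.wait_tensor.default

# Whether wait_tensor is currently treated as side-effectful by FX.
print(wait_op in torch.fx.node._side_effectful_functions)

with exclude_wait_from_fx_side_effectful():
    # Inside the context the wait ops are removed from the set, so
    # _extract_graph_with_inputs_outputs may drop waits whose collectives
    # land in the other split.
    assert wait_op not in torch.fx.node._side_effectful_functions

# The original set is restored on exit.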