105 changes: 105 additions & 0 deletions autoparallel/_passes/graph_multiplex.py
@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
Contributor Author:
The only real changes in this PR are example_llama3_di_dw.py and split_di_dw_graph.py; the rest is an artifact of me working on top of #205 and making the PR against main. I can move this graph pass somewhere else and/or wait for that other PR to land as necessary.

Contributor:
I would say that it would be good to start landing things instead of working from branches of branches.

Contributor Author:
Agreed. I was prioritizing "having a graph pass that Sanket can dump into his pipeline runtime" over having a PR that is ready to land. I'm happy to clean up this example to make it more self-contained and make it a proper test.

#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
import torch.fx as fx


def multiplex_fw_bw_graph(
    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule
) -> fx.GraphModule:
    """
    Multiplexes forward and backward graphs into a single unified graph module.

    This function combines a forward graph and a backward graph into one multiplexed
    graph by merging their nodes and outputs. The resulting graph has:
    - All placeholders from both forward and backward graphs (backward followed by forward)
    - All computation nodes from both graphs (backward followed by forward)
    - Combined outputs (backward outputs followed by forward outputs)

    Args:
        fw_gm: The forward graph module containing the forward computation
        bw_gm: The backward graph module containing the backward computation

    Returns:
        A multiplexed fx.GraphModule containing both forward and backward computations,
        with backward outputs appearing before forward outputs

    Note:
        The function preserves node metadata during the merging process.
    """
    # Mapping to track correspondence between backward graph nodes and new nodes
    old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {}

    # Start with a deep copy of the forward graph as the base
    multiplexed_gm = copy.deepcopy(fw_gm)

    # Collect all placeholder nodes from the backward graph
    bw_placeholders = []
    for n in bw_gm.graph.nodes:
        if n.op == "placeholder":
            bw_placeholders.append(n)

    # Insert backward placeholders at the beginning of the multiplexed graph
    # Reversed order ensures correct execution sequence
    with multiplexed_gm.graph.inserting_before():
        for n in reversed(bw_placeholders):
            new_placeholder = multiplexed_gm.graph.placeholder(n.name)
            new_placeholder.meta = n.meta
            new_placeholder.target = new_placeholder.name
            old_node_to_new_node[n] = new_placeholder

    # Find the last placeholder and the output node in the multiplexed graph
    insert_point = None
    multiplexed_graph_op_node = None
    for n in multiplexed_gm.graph.nodes:
        if n.op == "placeholder":
            insert_point = n
        if n.op == "output":
            multiplexed_graph_op_node = n

    # Copy all computation nodes from backward graph into multiplexed graph
    bw_graph_op_node = None
    for n in bw_gm.graph.nodes:
        if n.op == "placeholder":
            continue
        if n.op == "output":
            bw_graph_op_node = n
            continue
        with multiplexed_gm.graph.inserting_after(insert_point):
            # Copy node and remap its arguments using the node mapping
            new_node = multiplexed_gm.graph.node_copy(
                n, lambda x: old_node_to_new_node[x]
            )
            new_node.meta = n.meta
            old_node_to_new_node[n] = new_node
            insert_point = new_node

    assert bw_graph_op_node is not None
    assert multiplexed_graph_op_node is not None

    # Collect output arguments from backward graph, remapping to new nodes
    bw_op_node_args = [
        old_node_to_new_node[n] if n is not None else None
        for n in bw_graph_op_node.args[0]
    ]

    # Collect output arguments from forward graph
    fw_op_node_args = list(multiplexed_graph_op_node.args[0])

    # Remove the old output node and create new combined output
    insert_point = multiplexed_graph_op_node.prev
    multiplexed_gm.graph.erase_node(multiplexed_graph_op_node)

    # Create combined output with backward outputs first, then forward outputs
    with multiplexed_gm.graph.inserting_after(insert_point):
        multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args)

    multiplexed_gm.graph.eliminate_dead_code()
    multiplexed_gm.graph.lint()
    multiplexed_gm.recompile()
    return multiplexed_gm
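
For illustration only (not part of the PR): a minimal sketch of the calling convention, using two toy fx modules as stand-ins for real AOT-produced forward/backward graphs. The ToyFwd/ToyBwd modules and the tensor shapes below are made up.

import torch
import torch.fx as fx

from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph


class ToyFwd(torch.nn.Module):
    def forward(self, x):
        return (x.relu(),)


class ToyBwd(torch.nn.Module):
    def forward(self, grad_out):
        return (grad_out * 2,)


# Hedged sketch: toy graphs, not real forward/backward artifacts.
fw_gm = fx.symbolic_trace(ToyFwd())
bw_gm = fx.symbolic_trace(ToyBwd())
multiplexed_gm = multiplex_fw_bw_graph(fw_gm, bw_gm)

# Placeholders are ordered (backward inputs..., forward inputs...) and outputs
# are (backward outputs..., forward outputs...), per the docstring above.
grad_out_times_two, x_relu = multiplexed_gm(torch.ones(4), torch.randn(4))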
50 changes: 50 additions & 0 deletions autoparallel/_passes/split_di_dw_graph.py
@@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
import torch.fx as fx
from functorch.compile import default_partition

# We are running the default partitioner on the bw graph, which requires that the AC
# (activation checkpointing) tags be removed. At this stage we have already finished
# running AC anyway, since we already have a bw graph.
def remove_recompute_tags(bw_gm):
    for n in bw_gm.graph.nodes:
        if 'recompute' in n.meta:
            del n.meta['recompute']


# We are using the default partitioner to split our backward into dI and dW subgraphs.
# We want to generate the dI subgraph *first*, because:
# - in pipelining we generally want to schedule dI compute before dW
# - the dI compute will potentially compute more activations that we need to plumb
#   into the dW compute
# Today, the default partitioner requires that you split on the first K outputs of
# your combined graph, so here we reorder the outputs of the backward so that
# grad_inputs come first.
def reorder_output_grads(bw_gm, num_weight_gradients):
    outputs = bw_gm.graph.find_nodes(op='output')
    assert len(outputs) == 1
    output = outputs[0]
    assert isinstance(output.args[0], tuple)
    grad_weights, grad_inputs = (
        output.args[0][:num_weight_gradients],
        output.args[0][num_weight_gradients:],
    )
Contributor:
Why not use descriptors here?

Contributor Author:
I should definitely use descriptors. I mostly wanted to have something Sanket could try out quickly but happy to refactor.

    new_out_tuple = grad_inputs + grad_weights
    with bw_gm.graph.inserting_after(output):
        # TODO: also set the new node's meta properly
        new_out = bw_gm.graph.output(new_out_tuple)
    output.replace_all_uses_with(new_out)
    bw_gm.graph.erase_node(output)
    return len(grad_inputs)

# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
def split_di_dw_graph(bw_gm: fx.GraphModule, *, num_weight_gradients) -> tuple[fx.GraphModule, fx.GraphModule]:
Contributor:
Help me understand what bw_gm is; is this ONLY the linear?

Contributor Author:
So my understanding is that bw_gm here should be the entire backward graph of a given pipeline stage. A few things to note in this specific example file though:

(1) I am not using any kind of pipeline carving logic to get a fw/bw graph of a single stage; my dumb test just generates a fw/bw graph of the entire user model. I figured this is more self-contained, but it would be good to have an e2e test using the pipeline splitting logic too.

(2) technically, my code is currently performing the graph-split after Ivan's graph pass to split out FSDP allgathers from the fw and bw into a separate epilogue. In practice the order should probably not matter here, but one thing I noticed is that after running his pass, tangents_X inputs to the backward no longer show up in the main graph (left a comment here #201 (comment))

    # we could consider doing this in a non-mutating way
    bw_gm = copy.deepcopy(bw_gm)
    remove_recompute_tags(bw_gm)
    num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients)
    bw_gm.recompile()

    args = [x.meta['val'] for x in bw_gm.graph.find_nodes(op="placeholder")]

    bw_inputs, bw_weights = default_partition(
        bw_gm, args, num_fwd_outputs=num_input_gradients
    )
    return bw_inputs, bw_weights
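
For reference, a hedged sketch of how this pass might be driven (this is not the PR's example_llama3_di_dw.py): it assumes autop.bw_module is the whole-model backward graph exposed by the api.py change later in this diff, and that the backward emits exactly one weight gradient per parameter of parallel_mod; both are assumptions, not guarantees.

from autoparallel._passes.split_di_dw_graph import split_di_dw_graph

# Assumption: one gradient output per parameter, appearing first in the backward
# outputs as reorder_output_grads above expects.
num_weight_gradients = len(list(parallel_mod.parameters()))
bw_di_gm, bw_dw_gm = split_di_dw_graph(
    autop.bw_module, num_weight_gradients=num_weight_gradients
)
print("dI (grad-input) subgraph:")
print(bw_di_gm.graph)
print("dW (grad-weight) subgraph:")
print(bw_dw_gm.graph)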
61 changes: 61 additions & 0 deletions autoparallel/_passes/split_fsdp_collectives.py
@@ -0,0 +1,61 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses

import torch
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs


@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
    pass


def split_fsdp_prefetch(
    gm: torch.fx.GraphModule,
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
    g = gm.graph
    g_ins = g.find_nodes(op="placeholder")
    prefetch_g_outs_map = {}

    for g_in in g_ins:
        n = g_in
        while True:
            if len(n.users) != 1:
                break
            user = next(iter(n.users))
            if len(user.all_input_nodes) > 1:
                break
            n = user
        prefetch_g_outs_map[g_in] = n

    prefetch_g_outs = list(prefetch_g_outs_map.values())
    prefetch_g_outs_descs: list[AOTOutput] = [
        PrefetchOutput() for _ in range(len(prefetch_g_outs))
    ]

    prefetch_g = _extract_graph_with_inputs_outputs(
        g,
        g_ins,
        prefetch_g_outs,
        prefetch_g_outs_descs,
    )

    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )
    main_g = _extract_graph_with_inputs_outputs(
        g,
        prefetch_g_outs,
        g_outs,
        g_outs_descs,
    )
    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
    prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g)
    return prefetch_gm, main_gm
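
A self-contained sketch of what the split produces on a toy graph, assuming the same PyTorch nightly internals this file already imports; the dtype casts below are just a stand-in for the single-use chain (e.g. an FSDP all-gather) that gets hoisted into the prefetch graph.

import torch
import torch.fx as fx

from autoparallel._passes.split_fsdp_collectives import split_fsdp_prefetch


class Toy(torch.nn.Module):
    def forward(self, x, w):
        # w's single-user unary chain (the casts) lands in the prefetch graph;
        # the matmul, which has two inputs, stays in the main graph.
        return torch.matmul(x, w.to(torch.bfloat16).to(torch.float32))


gm = fx.symbolic_trace(Toy())
prefetch_gm, main_gm = split_fsdp_prefetch(gm)

# prefetch_gm: (x, w) -> (x, cast(w));  main_gm: (x, cast(w)) -> matmul output
x, w = torch.randn(4, 8), torch.randn(8, 2)
out = main_gm(*prefetch_gm(x, w))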
5 changes: 4 additions & 1 deletion autoparallel/api.py
@@ -467,11 +467,14 @@ def apply_placement(self, sharding_placement=None):
     torch.ops._c10d_functional.wait_tensor.default
 )

-self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors(
+parallel_model_fn, fw_module, bw_module = aot_compile_joint_with_descriptors(
     self.joint_with_descriptors,
     fw_compiler=self.compiler_fn,
     bw_compiler=self.compiler_fn,
 )
+self.parallel_model_fn = parallel_model_fn
+self.fw_module = fw_module
+self.bw_module = bw_module

 # TODO: this probably belongs in the AOTAutograd API
 # TODO: pytree handling
25 changes: 24 additions & 1 deletion examples/example_llama3.py
@@ -11,6 +11,8 @@
 from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
 from torch.testing._internal.distributed.fake_pg import FakeStore

+from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
+from autoparallel._passes.split_fsdp_collectives import split_fsdp_prefetch
 from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
 from autoparallel.api import AutoParallel
 from autoparallel.auto_bucketing import (
@@ -57,7 +59,7 @@ def model_fn():
 if model_type == "8b":
     model_args = TransformerModelArgs(
         dim=4096,
-        n_layers=32,
+        n_layers=1,
         n_heads=32,
         n_kv_heads=8,
         ffn_dim_multiplier=1.3,
@@ -252,6 +254,27 @@ def _pass(graph):
 sharding_placement = autop.optimize_placement(verbose=True)
 print(f"Took {time.time() - t:.2f} s")
 parallel_mod = autop.apply_placement(sharding_placement)
+multiplex_graph = True
+if multiplex_graph:
+    f_gm = autop.fw_module
+    b_gm = autop.bw_module
+    print("Original Fwd Graph:")
+    print(f_gm.graph)
+    print("Original Bwd Graph:")
+    print(b_gm.graph)
+    prefetch_f_gm, main_f_gm = split_fsdp_prefetch(f_gm)
+    print("Main Fwd Graph:")
+    print(main_f_gm.graph)
+    print("Prefetch Fwd Graph:")
+    print(prefetch_f_gm.graph)
+    prefetch_b_gm, main_b_gm = split_fsdp_prefetch(b_gm)
+    print("Main Bwd Graph:")
+    print(main_b_gm.graph)
+    print("Prefetch Bwd Graph:")
+    print(prefetch_b_gm.graph)
+    multiplexed_gm = multiplex_fw_bw_graph(main_f_gm, main_b_gm)
+    print("Multiplexed Graph:")
+    print(multiplexed_gm.graph)

 # run weight init on our sharded DTensor params
 parallel_mod.to_empty(device="cuda")