graph pass dI/dW split example #212
@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
import torch.fx as fx


def multiplex_fw_bw_graph(
    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule
) -> fx.GraphModule:
    """
    Multiplexes forward and backward graphs into a single unified graph module.

    This function combines a forward graph and a backward graph into one multiplexed
    graph by merging their nodes and outputs. The resulting graph has:
    - All placeholders from both forward and backward graphs (backward followed by forward)
    - All computation nodes from both graphs (backward followed by forward)
    - Combined outputs (backward outputs followed by forward outputs)

    Args:
        fw_gm: The forward graph module containing the forward computation
        bw_gm: The backward graph module containing the backward computation

    Returns:
        A multiplexed fx.GraphModule containing both forward and backward computations
        with backward outputs appearing before forward outputs

    Note:
        The function preserves node metadata during the merging process.
    """
    # Mapping to track correspondence between backward graph nodes and new nodes
    old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {}

    # Start with a deep copy of the forward graph as the base
    multiplexed_gm = copy.deepcopy(fw_gm)

    # Collect all placeholder nodes from the backward graph
    bw_placeholders = []
    for n in bw_gm.graph.nodes:
        if n.op == "placeholder":
            bw_placeholders.append(n)

    # Insert backward placeholders at the beginning of the multiplexed graph.
    # Iterating in reverse while inserting at the front preserves the original
    # placeholder order of the backward graph.
    with multiplexed_gm.graph.inserting_before():
        for n in reversed(bw_placeholders):
            new_placeholder = multiplexed_gm.graph.placeholder(n.name)
            new_placeholder.meta = n.meta
            new_placeholder.target = new_placeholder.name
            old_node_to_new_node[n] = new_placeholder

    # Find the last placeholder and the output node in the multiplexed graph
    insert_point = None
    multiplexed_graph_op_node = None
    for n in multiplexed_gm.graph.nodes:
        if n.op == "placeholder":
            insert_point = n
        if n.op == "output":
            multiplexed_graph_op_node = n

    # Copy all computation nodes from backward graph into multiplexed graph
    bw_graph_op_node = None
    for n in bw_gm.graph.nodes:
        if n.op == "placeholder":
            continue
        if n.op == "output":
            bw_graph_op_node = n
            continue
        with multiplexed_gm.graph.inserting_after(insert_point):
            # Copy node and remap its arguments using the node mapping
            new_node = multiplexed_gm.graph.node_copy(
                n, lambda x: old_node_to_new_node[x]
            )
            new_node.meta = n.meta
            old_node_to_new_node[n] = new_node
            insert_point = new_node

    assert bw_graph_op_node is not None
    assert multiplexed_graph_op_node is not None

    # Collect output arguments from backward graph, remapping to new nodes
    bw_op_node_args = [
        old_node_to_new_node[n] if n is not None else None
        for n in bw_graph_op_node.args[0]
    ]

    # Collect output arguments from forward graph
    fw_op_node_args = list(multiplexed_graph_op_node.args[0])

    # Remove the old output node and create new combined output
    insert_point = multiplexed_graph_op_node.prev
    multiplexed_gm.graph.erase_node(multiplexed_graph_op_node)

    # Create combined output with backward outputs first, then forward outputs
    with multiplexed_gm.graph.inserting_after(insert_point):
        multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args)

    multiplexed_gm.graph.eliminate_dead_code()
    multiplexed_gm.graph.lint()
    multiplexed_gm.recompile()
    return multiplexed_gm
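
To make the ordering described in the docstring concrete, here is a minimal sketch that multiplexes two hand-traced graphs. The toy `toy_fwd`/`toy_bwd` functions and the use of `fx.symbolic_trace` are assumptions made only for this example; in the PR, the inputs would be the forward/backward graph modules produced by AOT Autograd's partitioner.

```python
# Illustrative sketch only: toy_fwd/toy_bwd and symbolic_trace stand in for the
# real AOT Autograd forward/backward graph modules.
import torch
import torch.fx as fx

def toy_fwd(x, w):
    # return a tuple so the output node carries a tuple of outputs, as AOT graphs do
    return (torch.matmul(x, w),)

def toy_bwd(g, a, b):
    # hand-written stand-in for a backward graph
    return (torch.matmul(g, b.t()), torch.matmul(a.t(), g))

fw_gm = fx.symbolic_trace(toy_fwd)
bw_gm = fx.symbolic_trace(toy_bwd)

multiplexed = multiplex_fw_bw_graph(fw_gm, bw_gm)
# Placeholders: (g, a, b) from the backward graph, then (x, w) from the forward graph.
# Outputs: the two backward grads, then the single forward output.
print(multiplexed.graph)
```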
@@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
import torch.fx as fx
from functorch.compile import default_partition


# We are running the default partitioner on the bw graph, which requires the AC
# (activation checkpointing) tags to be removed. At this stage we have already
# finished running AC anyway, since we already have a bw graph.
def remove_recompute_tags(bw_gm):
    for n in bw_gm.graph.nodes:
        if 'recompute' in n.meta:
            del n.meta['recompute']


# We are using the default partitioner to split our backward into dI and dW subgraphs.
# We want to generate the dI subgraph *first*, because:
# - in pipelining we generally want to schedule dI compute before dW
# - the dI compute will potentially compute more activations that we need to plumb into dW compute
# Today, the default partitioner requires that you split on the first K outputs of your combined graph.
# So here, we reorder the outputs of the backward so grad_inputs come first.
def reorder_output_grads(bw_gm, num_weight_gradients):
    outputs = bw_gm.graph.find_nodes(op='output')
    assert len(outputs) == 1
    output = outputs[0]
    assert isinstance(output.args[0], tuple)
    grad_weights = output.args[0][:num_weight_gradients]
    grad_inputs = output.args[0][num_weight_gradients:]
Contributor: Why not use descriptors here?

Author: I should definitely use descriptors. I mostly wanted to have something Sanket could try out quickly but happy to refactor.
    new_out_tuple = grad_inputs + grad_weights
    with bw_gm.graph.inserting_after(output):
        # TODO: also set the new node's meta properly
        new_out = bw_gm.graph.output(new_out_tuple)
    output.replace_all_uses_with(new_out)
    bw_gm.graph.erase_node(output)
    return len(grad_inputs)


# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
def split_di_dw_graph(bw_gm: fx.GraphModule, *, num_weight_gradients) -> tuple[fx.GraphModule, fx.GraphModule]:
Contributor: Help me understand what bw_gm is; is this ONLY the linear?

Author: So my understanding is that (1) I am not using any kind of pipeline carving logic to get a fw/bw graph of a single stage, my dumb test is just generating a fw/bw graph of the entire user model. I figured this is more self contained, but it would be good to have an e2e test using the pipeline splitting logic too. (2) technically, my code is currently performing the graph-split after Ivan's graph pass to split out FSDP allgathers from the fw and bw into a separate epilogue. In practice the order should probably not matter here, but one thing I noticed is that after running his pass, …
    # we could consider doing this in a non-mutating way
    bw_gm = copy.deepcopy(bw_gm)
    remove_recompute_tags(bw_gm)
    num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients)
    bw_gm.recompile()

    args = [x.meta['val'] for x in bw_gm.graph.find_nodes(op="placeholder")]

    bw_inputs, bw_weights = default_partition(bw_gm, args, num_fwd_outputs=num_input_gradients)
    return bw_inputs, bw_weights
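
For reference, a rough usage sketch. Everything below other than `split_di_dw_graph` itself is an assumption made for illustration: the toy model, capturing the backward graph via `functorch.compile.aot_function` with a custom `bw_compiler`, and the assumption that (with the weight passed first) the weight gradient is the first backward output, so `num_weight_gradients=1` lines up.

```python
# Rough sketch: aot_function plus a bw_compiler that stashes the backward graph.
# The model, the capture mechanism, and the grad-output ordering are assumptions.
import torch
from functorch.compile import aot_function, default_partition

captured = {}

def fw_compiler(gm, example_inputs):
    return gm

def bw_compiler(gm, example_inputs):
    captured["bw_gm"] = gm  # stash the backward GraphModule for splitting
    return gm

def f(w, x):
    return torch.nn.functional.linear(x, w).relu()

w = torch.randn(16, 8, requires_grad=True)  # the single "weight"; its grad assumed first
x = torch.randn(4, 8, requires_grad=True)

compiled_f = aot_function(f, fw_compiler=fw_compiler, bw_compiler=bw_compiler,
                          partition_fn=default_partition)
compiled_f(w, x).sum().backward()  # triggers backward compilation and the capture

dI_gm, dW_gm = split_di_dw_graph(captured["bw_gm"], num_weight_gradients=1)
print(dI_gm.graph)  # grad-input subgraph, scheduled first under pipelining
print(dW_gm.graph)  # grad-weight subgraph, consuming activations plumbed from dI
```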
@@ -0,0 +1,61 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses

import torch
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs


@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
    pass


def split_fsdp_prefetch(
    gm: torch.fx.GraphModule,
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
    g = gm.graph
    g_ins = g.find_nodes(op="placeholder")
    prefetch_g_outs_map = {}

    # For each graph input, walk forward through its chain of single-user,
    # single-input nodes (e.g. an FSDP all-gather of a parameter); the last node
    # of that chain becomes the corresponding output of the prefetch graph.
    for g_in in g_ins:
        n = g_in
        while True:
            if len(n.users) != 1:
                break
            user = next(iter(n.users))
            if len(user.all_input_nodes) > 1:
                break
            n = user
        prefetch_g_outs_map[g_in] = n

    prefetch_g_outs = list(prefetch_g_outs_map.values())
    prefetch_g_outs_descs: list[AOTOutput] = [
        PrefetchOutput() for _ in range(len(prefetch_g_outs))
    ]

    prefetch_g = _extract_graph_with_inputs_outputs(
        g,
        g_ins,
        prefetch_g_outs,
        prefetch_g_outs_descs,
    )

    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )
    main_g = _extract_graph_with_inputs_outputs(
        g,
        prefetch_g_outs,
        g_outs,
        g_outs_descs,
    )
    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
    prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g)
    return prefetch_gm, main_gm
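
A hypothetical call site, to show how the two returned modules relate. Names like `gm` and `graph_inputs` are placeholders for this sketch, not part of the PR:

```python
# Hypothetical usage sketch: gm is a forward or backward GraphModule whose inputs
# feed single-user chains (e.g. FSDP all-gathers); graph_inputs are its runtime args.
prefetch_gm, main_gm = split_fsdp_prefetch(gm)
# prefetch_gm: original graph inputs -> end of each per-input chain (e.g. gathered params)
# main_gm:     those chain outputs   -> the original graph outputs
gathered = prefetch_gm(*graph_inputs)  # can be issued early, overlapping communication
outputs = main_gm(*gathered)           # main compute consumes the prefetched values
```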
Author: The only real changes in this PR are example_llama3_di_dw.py and split_di_dw_graph.py; the rest is an artifact of me working on top of #205 and making the PR against main. I can move this graph pass somewhere else and/or wait for that other PR to land as necessary.

Contributor: I would say that it would be good to start landing things instead of working from branches of branches.

Author: Agreed. I was prioritizing "having a graph pass that Sanket can dump into his pipeline runtime" over having a PR that is ready to land. I'm happy to clean up this example to make it more self contained + make it a proper test.