|
| 1 | +import copy |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.fx as fx |
| 5 | + |
| 6 | + |
| 7 | +def multiplex_fw_bw_graph( |
| 8 | + fw_gm: fx.GraphModule, bw_gm: fx.GraphModule |
| 9 | +) -> fx.GraphModule: |
| 10 | + """ |
| 11 | + Multiplexes forward and backward graphs into a single unified graph module. |
| 12 | +
|
| 13 | + This function combines a forward graph and a backward graph into one multiplexed |
| 14 | + graph by merging their nodes and outputs. The resulting graph has: |
| 15 | + - All placeholders from both forward and backward graphs (backward followed by forward) |
| 16 | + - All computation nodes from both graphs (backward followed by forward) |
| 17 | + - Combined outputs (backward outputs followed by forward outputs) |
| 18 | +
|
| 19 | + Args: |
| 20 | + fw_gm: The forward graph module containing the forward computation |
| 21 | + bw_gm: The backward graph module containing the backward computation |
| 22 | +
|
| 23 | + Returns: |
| 24 | + A multiplexed fx.GraphModule containing both forward and backward computations |
| 25 | + with backward outputs appearing before forward outputs |
| 26 | +
|
| 27 | + Note: |
| 28 | + The function preserves node metadata during the merging process. |
| 29 | + """ |
| 30 | + # Mapping to track correspondence between backward graph nodes and new nodes |
| 31 | + old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {} |
| 32 | + |
| 33 | + # Start with a deep copy of the forward graph as the base |
| 34 | + multiplexed_gm = copy.deepcopy(fw_gm) |
| 35 | + |
| 36 | + # Collect all placeholder nodes from the backward graph |
| 37 | + bw_placeholders = [] |
| 38 | + for n in bw_gm.graph.nodes: |
| 39 | + if n.op == "placeholder": |
| 40 | + bw_placeholders.append(n) |
| 41 | + |
| 42 | + # Insert backward placeholders at the beginning of the multiplexed graph |
| 43 | + # Reversed order ensures correct execution sequence |
| 44 | + with multiplexed_gm.graph.inserting_before(): |
| 45 | + for n in reversed(bw_placeholders): |
| 46 | + new_placeholder = multiplexed_gm.graph.placeholder(n.name) |
| 47 | + new_placeholder.meta = n.meta |
| 48 | + new_placeholder.target = new_placeholder.name |
| 49 | + old_node_to_new_node[n] = new_placeholder |
| 50 | + |
| 51 | + # Find the last placeholder and the output node in the multiplexed graph |
| 52 | + insert_point = None |
| 53 | + multiplexed_graph_op_node = None |
| 54 | + for n in multiplexed_gm.graph.nodes: |
| 55 | + if n.op == "placeholder": |
| 56 | + insert_point = n |
| 57 | + if n.op == "output": |
| 58 | + multiplexed_graph_op_node = n |
| 59 | + |
| 60 | + # Copy all computation nodes from backward graph into multiplexed graph |
| 61 | + bw_graph_op_node = None |
| 62 | + for n in bw_gm.graph.nodes: |
| 63 | + if n.op == "placeholder": |
| 64 | + continue |
| 65 | + if n.op == "output": |
| 66 | + bw_graph_op_node = n |
| 67 | + continue |
| 68 | + with multiplexed_gm.graph.inserting_after(insert_point): |
| 69 | + # Copy node and remap its arguments using the node mapping |
| 70 | + new_node = multiplexed_gm.graph.node_copy( |
| 71 | + n, lambda x: old_node_to_new_node[x] |
| 72 | + ) |
| 73 | + new_node.meta = n.meta |
| 74 | + old_node_to_new_node[n] = new_node |
| 75 | + insert_point = new_node |
| 76 | + |
| 77 | + assert bw_graph_op_node is not None |
| 78 | + assert multiplexed_graph_op_node is not None |
| 79 | + |
| 80 | + # Collect output arguments from backward graph, remapping to new nodes |
| 81 | + bw_op_node_args = [ |
| 82 | + old_node_to_new_node[n] if n is not None else None |
| 83 | + for n in bw_graph_op_node.args[0] |
| 84 | + ] |
| 85 | + |
| 86 | + # Collect output arguments from forward graph |
| 87 | + fw_op_node_args = list(multiplexed_graph_op_node.args[0]) |
| 88 | + |
| 89 | + # Remove the old output node and create new combined output |
| 90 | + insert_point = multiplexed_graph_op_node.prev |
| 91 | + multiplexed_gm.graph.erase_node(multiplexed_graph_op_node) |
| 92 | + |
| 93 | + # Create combined output with backward outputs first, then forward outputs |
| 94 | + with multiplexed_gm.graph.inserting_after(insert_point): |
| 95 | + multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args) |
| 96 | + |
| 97 | + multiplexed_gm.graph.eliminate_dead_code() |
| 98 | + multiplexed_gm.graph.lint() |
| 99 | + multiplexed_gm.recompile() |
| 100 | + return multiplexed_gm |
0 commit comments