|
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
6 | 6 | import copy |
| 7 | +import itertools |
| 8 | +import operator |
7 | 9 |
|
| 10 | +import sympy |
| 11 | +import torch |
8 | 12 | import torch.fx as fx |
9 | | -from functorch.compile import default_partition |
| 13 | +from torch._functorch.partitioners import ( |
| 14 | + SavedForBackwardsAOTOutput, |
| 15 | + _extract_fwd_bwd_outputs, |
| 16 | + _extract_graph_with_inputs_outputs, |
| 17 | + _is_backward_state, |
| 18 | + _is_bwd_seed_offset, |
| 19 | + _is_fwd_seed_offset, |
| 20 | + _is_primal, |
| 21 | + _remove_by_name, |
| 22 | + find_symbol_binding_fx_nodes, |
| 23 | + free_symbols, |
| 24 | + is_sym_node, |
| 25 | + is_symbol_binding_fx_node, |
| 26 | +) |
| 27 | +from torch.utils._ordered_set import OrderedSet |
| 28 | + |
| 29 | +from autoparallel.apply_sharding import rename_placeholder_node |
10 | 30 |
|
11 | 31 | # we are running the default partitioner on the bw graph, which requires AC tags being removed. |
12 | 32 | # At this stage we have already finished running AC anyway, since we have a bw graph |
@@ -44,21 +64,203 @@ def reorder_output_grads(bw_gm, num_weight_gradients): |
44 | 64 | return len(grad_inputs) |
45 | 65 |
|
46 | 66 |
|
47 | | -# TODO: in theory we can infer num_weight_gradients from the graph metadata directly |
| 67 | +# This is a copy of the helper used by the default partitioner, except that it does
| 68 | +# *not* reorder symint activations to the front of the backward inputs.
| 69 | +# That reordering is needed by the custom autograd.Function in AOTDispatcher,
| 70 | +# but isn't needed for our dI/dW splitting since there is no autograd in the loop.
| 71 | +# TODO: provide a way to get this behavior automatically out of the default partitioner
| 72 | +def _extract_fwd_bwd_modules( |
| 73 | + joint_module: fx.GraphModule, |
| 74 | + saved_values: list[fx.Node], |
| 75 | + saved_sym_nodes: list[fx.Node], |
| 76 | + *, |
| 77 | + num_fwd_outputs: int, |
| 78 | +) -> tuple[fx.GraphModule, fx.GraphModule]: |
| 79 | + ( |
| 80 | + fwd_outputs, |
| 81 | + bwd_outputs, |
| 82 | + fwd_outputs_descs, |
| 83 | + bwd_outputs_descs, |
| 84 | + ) = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) |
| 85 | + placeholders = joint_module.graph.find_nodes(op="placeholder") |
| 86 | + primal_inputs = [*filter(_is_primal, placeholders)] |
| 87 | + fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)] |
| 88 | + bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)] |
| 89 | + backward_state_inputs = [*filter(_is_backward_state, placeholders)] |
| 90 | + |
| 91 | + bwd_graph = _extract_graph_with_inputs_outputs( |
| 92 | + joint_module.graph, |
| 93 | + saved_values + saved_sym_nodes + bwd_seed_offset_inputs, |
| 94 | + bwd_outputs, |
| 95 | + bwd_outputs_descs, |
| 96 | + "backward", |
| 97 | + ignore_must_be_in_fw_bw=True, |
| 98 | + ) |
| 99 | + |
| 100 | + distributed_enabled = torch.distributed.is_available() |
| 101 | + |
| 102 | + for node in bwd_graph.find_nodes(op="placeholder"): |
| 103 | + # This is to filter out saved values that don't actually end up being used by the backwards pass |
| 104 | + if not node.users: |
| 105 | + _remove_by_name(saved_values, node.name) |
| 106 | + _remove_by_name(saved_sym_nodes, node.name) |
| 107 | + # wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw, |
| 108 | + # but this dead activation is actually a collective, |
| 109 | +        # then the collective will generally be followed by a wait_tensor() call.
| 110 | +        # we need to peek one node further to see if this wait_tensor is dead as well.
| 111 | + elif distributed_enabled and all( |
| 112 | + n.target is torch.ops._c10d_functional.wait_tensor.default |
| 113 | + and len(n.users) == 0 |
| 114 | + for n in node.users |
| 115 | + ): |
| 116 | + _remove_by_name(saved_values, node.name) |
| 117 | + _remove_by_name(saved_sym_nodes, node.name) |
| 118 | + elif _is_backward_state(node): |
| 119 | + # BackwardState is saved directly |
| 120 | + _remove_by_name(saved_values, node.name) |
| 121 | + assert backward_state_inputs |
| 122 | + |
| 123 | + # Now that we have the finalized list of saved values, we need to ensure |
| 124 | + # we propagate all symbols which are referenced by backwards inputs. |
| 125 | + # These are not directly used in the graph but are required for downstream |
| 126 | + # sizevar assignment |
| 127 | + saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet() |
| 128 | + saved_sym_nodes_binding = [] |
| 129 | + saved_sym_nodes_derived = [] |
| 130 | + |
| 131 | +    # Some symbols may already be bound by the directly saved sym nodes;
| 132 | +    # keep track of them so we don't re-bind them
| 133 | + for node in saved_sym_nodes: |
| 134 | + symbol = is_symbol_binding_fx_node(node) |
| 135 | + if symbol: |
| 136 | + saved_symbols.add(symbol) |
| 137 | + saved_sym_nodes_binding.append(node) |
| 138 | + else: |
| 139 | + saved_sym_nodes_derived.append(node) |
| 140 | + |
| 141 | + # Now go through all of the prospective backward inputs and track any |
| 142 | + # other symbols we need to bind |
| 143 | + symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph) |
| 144 | + for node in itertools.chain(saved_sym_nodes_derived, saved_values): |
| 145 | + if "val" not in node.meta: |
| 146 | + continue |
| 147 | + new_symbols = free_symbols(node.meta["val"]) - saved_symbols |
| 148 | + # NB: Deterministic order please! |
| 149 | + for s in sorted(new_symbols, key=lambda s: s.name): |
| 150 | + # NB: For well formed graphs, the symbol should always be present, |
| 151 | + # but we also have ways to produce ill-formed graphs, e.g., direct |
| 152 | + # make_fx usages, so don't choke in this case |
| 153 | + if s not in symbol_bindings: |
| 154 | + continue |
| 155 | + saved_sym_nodes_binding.append(symbol_bindings[s]) |
| 156 | + saved_symbols |= new_symbols |
| 157 | + |
| 158 | + # Update saved_sym_nodes that are now reordered to have all bindings at |
| 159 | + # front. This can also be used later on to figure out the position of saved |
| 160 | + # sym nodes in the output of fwd graph. |
| 161 | + saved_sym_nodes.clear() |
| 162 | + saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived) |
48 | 163 |
|
| 164 | + # Now, we re-generate the fwd/bwd graphs. |
| 165 | + # NB: This might increase compilation time, but I doubt it matters |
| 166 | + fwd_graph = _extract_graph_with_inputs_outputs( |
| 167 | + joint_module.graph, |
| 168 | + primal_inputs + fwd_seed_offset_inputs, |
| 169 | + fwd_outputs + saved_values + saved_sym_nodes, |
| 170 | + fwd_outputs_descs |
| 171 | + + [ |
| 172 | + SavedForBackwardsAOTOutput(i) |
| 173 | + for i in range(len(saved_values) + len(saved_sym_nodes)) |
| 174 | + ], |
| 175 | + "forward", |
| 176 | + ignore_must_be_in_fw_bw=True, |
| 177 | + ) |
| 178 | + bwd_graph = _extract_graph_with_inputs_outputs( |
| 179 | + joint_module.graph, |
| 180 | + saved_values + saved_sym_nodes + bwd_seed_offset_inputs + backward_state_inputs, |
| 181 | + bwd_outputs, |
| 182 | + bwd_outputs_descs, |
| 183 | + "backward", |
| 184 | + ignore_must_be_in_fw_bw=True, |
| 185 | + ) |
| 186 | + |
| 187 | + fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph) |
| 188 | + bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph) |
| 189 | + return fwd_module, bwd_module |
49 | 190 |
|
| 191 | + |
| 192 | +# TODO: in theory we can infer num_weight_gradients from the graph metadata directly |
50 | 193 | def split_di_dw_graph( |
51 | | - bw_gm: fx.GraphModule, *, num_weight_gradients |
52 | | -) -> tuple[fx.GraphModule, fx.GraphModule]: |
| 194 | + bw_gm_old: fx.GraphModule, *, num_weight_gradients: int |
| 195 | +) -> tuple[fx.GraphModule, fx.GraphModule, int]: |
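| | +    # Split the backward graph into a dI graph (computing the input/activation
| | +    # gradients) and a dW graph (computing the weight gradients from values saved
| | +    # by the dI graph); also return the number of input gradients.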
53 | 196 | # we could consider doing this in a non-mutating way
54 | | - bw_gm = copy.deepcopy(bw_gm) |
| 197 | + bw_gm = copy.deepcopy(bw_gm_old) |
| 198 | + placeholders = bw_gm.graph.find_nodes(op="placeholder") |
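| | +    # the partitioner helpers special-case placeholders named "tangents_*"
| | +    # (e.g. _is_primal does not treat them as primal inputs), so rename them
| | +    # before re-partitioning the graph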
| 199 | + for p in placeholders: |
| 200 | + if p.name.startswith("tangent"): |
| 201 | + name_suffix = p.name[8:] |
| 202 | + rename_placeholder_node(bw_gm, p, f"not_tngnt{name_suffix}") |
| 203 | + |
55 | 204 | remove_recompute_tags(bw_gm) |
56 | 205 | num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients) |
57 | 206 | bw_gm.recompile() |
58 | 207 |
|
59 | | - args = [x.meta["val"] for x in bw_gm.graph.find_nodes(op="placeholder")] |
| 208 | + args = list(bw_gm.graph.find_nodes(op="placeholder")) |
| 209 | + |
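| | +    # previous approach, kept for reference: just call the default partitioner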
| 210 | + # bw_inputs, bw_weights = default_partition(bw_gm, args, num_fwd_outputs=num_input_gradients) |
| 211 | + # return bw_inputs, bw_weights, num_input_gradients |
| 212 | + |
| 213 | + ( |
| 214 | + grad_inps, |
| 215 | + grad_weights, |
| 216 | + grad_inp_descs, |
| 217 | + grad_weight_descs, |
| 218 | + ) = _extract_fwd_bwd_outputs(bw_gm, num_fwd_outputs=num_input_gradients) |
| 219 | + bw_inputs_gm = _extract_graph_with_inputs_outputs( |
| 220 | + bw_gm.graph, |
| 221 | + args, |
| 222 | + grad_inps, |
| 223 | + grad_inp_descs, |
| 224 | + "forward", |
| 225 | + ignore_must_be_in_fw_bw=True, |
| 226 | + ) |
| 227 | + bw_inputs_gm_node_names = OrderedSet( |
| 228 | + node.name for node in bw_inputs_gm.nodes if node.op != "output" |
| 229 | + ) |
| 230 | + saved_values = [] |
| 231 | + saved_sym_nodes = [] |
60 | 232 |
|
61 | | - bw_inputs, bw_weights = default_partition( |
62 | | - bw_gm, args, num_fwd_outputs=num_input_gradients |
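| | +    # Mirror the default partitioner's strategy: every node that ends up in the dI
| | +    # ("forward") graph becomes a candidate saved value for the dW ("backward") graph
| | +    # (sym nodes and multi-output ops get special handling; unused ones are pruned
| | +    # later in _extract_fwd_bwd_modules).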
| 233 | + for node in bw_gm.graph.nodes: |
| 234 | + if node.name not in bw_inputs_gm_node_names: |
| 235 | +            # Not handling mutations for now;
| 236 | +            # we could reuse more of / consolidate with the default partitioner here
| 237 | + continue |
| 238 | + if is_sym_node(node): |
| 239 | + saved_sym_nodes.append(node) |
| 240 | + elif ( |
| 241 | + "tensor_meta" not in node.meta |
| 242 | + and node.op == "call_function" |
| 243 | + and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor) |
| 244 | + ): |
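| | +            # multi-output op: its "val" isn't a single FakeTensor, so save the
| | +            # getitem projections of its outputs instead of the node itself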
| 245 | + users = node.users |
| 246 | + assert all(user.target == operator.getitem for user in users) |
| 247 | + saved_values.extend(users) |
| 248 | + else: |
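| | +            # if the dW graph only reads symbolic sizes of this tensor, save those
| | +            # sym nodes instead of the full tensor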
| 249 | + backward_usages = [ |
| 250 | + n for n in node.users if n.name not in bw_inputs_gm_node_names |
| 251 | + ] |
| 252 | + if "tensor_meta" in node.meta and all( |
| 253 | + is_sym_node(n) for n in backward_usages |
| 254 | + ): |
| 255 | + saved_sym_nodes.extend(backward_usages) |
| 256 | + else: |
| 257 | + saved_values.append(node) |
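| | +    # de-duplicate while preserving insertion order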
| 258 | + saved_values = list(dict.fromkeys(saved_values).keys()) |
| 259 | + saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) |
| 260 | + bw_inputs, bw_weights = _extract_fwd_bwd_modules( |
| 261 | + bw_gm, |
| 262 | + saved_values, |
| 263 | + saved_sym_nodes=saved_sym_nodes, |
| 264 | + num_fwd_outputs=num_input_gradients, |
63 | 265 | ) |
64 | | - return bw_inputs, bw_weights |
| 266 | + return bw_inputs, bw_weights, num_input_gradients |