|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import dataclasses |
| 7 | +from contextlib import contextmanager |
| 8 | +from functools import partial |
| 9 | +from typing import Any |
| 10 | + |
| 11 | +import torch |
| 12 | +import torch.fx.node |
| 13 | +import torch.utils._pytree as pytree |
| 14 | +from torch._functorch._aot_autograd.descriptors import AOTOutput |
| 15 | +from torch._functorch.partitioners import _extract_graph_with_inputs_outputs |
| 16 | +from torch._inductor.fx_passes.bucketing import ( |
| 17 | + is_all_gather_into_tensor, |
| 18 | + is_reduce_scatter_tensor, |
| 19 | +) |
| 20 | + |
| 21 | + |
| 22 | +@contextmanager |
| 23 | +def exclude_from_fx_side_effectful(exclude_vals: set[Any]): |
| 24 | + original_val = torch.fx.node._side_effectful_functions.copy() |
| 25 | + try: |
| 26 | + torch.fx.node._side_effectful_functions -= exclude_vals |
| 27 | + yield |
| 28 | + finally: |
| 29 | + torch.fx.node._side_effectful_functions.clear() |
| 30 | + torch.fx.node._side_effectful_functions.update(original_val) |
| 31 | + |
| 32 | + |
| 33 | +exclude_wait_from_fx_side_effectful = partial( |
| 34 | + exclude_from_fx_side_effectful, |
| 35 | + { |
| 36 | + torch.ops._c10d_functional.wait_tensor, |
| 37 | + torch.ops._c10d_functional.wait_tensor.default, |
| 38 | + }, |
| 39 | +) |
| 40 | + |
| 41 | + |
| 42 | +@dataclasses.dataclass(frozen=True) |
| 43 | +class PrefetchOutput(AOTOutput): |
| 44 | + pass |
| 45 | + |
| 46 | + |
| 47 | +@dataclasses.dataclass(frozen=True) |
| 48 | +class EpilogueInput(AOTOutput): |
| 49 | + pass |
| 50 | + |
| 51 | + |
| 52 | +def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]: |
| 53 | + g_ins = g.find_nodes(op="placeholder") |
| 54 | + prefetch_g_outs_map = [] |
| 55 | + |
| 56 | + for g_in in g_ins: |
| 57 | + n = g_in |
| 58 | + last_ag = None |
| 59 | + while True: |
| 60 | + if len(n.users) != 1: |
| 61 | + break |
| 62 | + user = next(iter(n.users)) |
| 63 | + if len(user.all_input_nodes) > 1: |
| 64 | + break |
| 65 | + n = user |
| 66 | + if is_all_gather_into_tensor(n): |
| 67 | + last_ag = n |
| 68 | + if last_ag is None: |
| 69 | + prefetch_g_outs_map.append(g_in) |
| 70 | + else: |
| 71 | + w_n = next(iter(last_ag.users)) |
| 72 | + prefetch_g_outs_map.append(w_n) |
| 73 | + |
| 74 | + prefetch_g_outs = prefetch_g_outs_map |
| 75 | + prefetch_g_outs_descs: list[AOTOutput] = [ |
| 76 | + PrefetchOutput() for _ in range(len(prefetch_g_outs)) |
| 77 | + ] |
| 78 | + g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) |
| 79 | + g_outs_descs = pytree.arg_tree_leaves( |
| 80 | + next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs)) |
| 81 | + ) |
| 82 | + with exclude_wait_from_fx_side_effectful(): |
| 83 | + prefetch_g = _extract_graph_with_inputs_outputs( |
| 84 | + g, |
| 85 | + g_ins, |
| 86 | + prefetch_g_outs, |
| 87 | + prefetch_g_outs_descs, |
| 88 | + ignore_must_be_in_fw_bw=True, |
| 89 | + ) |
| 90 | + |
| 91 | + main_g = _extract_graph_with_inputs_outputs( |
| 92 | + g, |
| 93 | + prefetch_g_outs, |
| 94 | + g_outs, |
| 95 | + g_outs_descs, |
| 96 | + ignore_must_be_in_fw_bw=True, |
| 97 | + ) |
| 98 | + return prefetch_g, main_g |
| 99 | + |
| 100 | + |
| 101 | +def split_fsdp_reduce_scatters_epilogue( |
| 102 | + g: torch.fx.Graph, |
| 103 | +) -> tuple[torch.fx.Graph, torch.fx.Graph]: |
| 104 | + g_ins = g.find_nodes(op="placeholder") |
| 105 | + g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) |
| 106 | + g_outs_descs = pytree.arg_tree_leaves( |
| 107 | + next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs)) |
| 108 | + ) |
| 109 | + |
| 110 | + g_outs_map = [] |
| 111 | + for g_out in g_outs: |
| 112 | + n = g_out |
| 113 | + last_rs = None |
| 114 | + while n is not None: |
| 115 | + if len(n.all_input_nodes) != 1: |
| 116 | + break |
| 117 | + n_in = n.all_input_nodes[0] |
| 118 | + if len(n_in.users) > 1: |
| 119 | + break |
| 120 | + prev_n = n |
| 121 | + n = n_in |
| 122 | + if is_reduce_scatter_tensor(prev_n): |
| 123 | + # In AP for mesh dim > 1 |
| 124 | + # The reduction of gradients happen in multiple steps |
| 125 | + last_rs = n |
| 126 | + if last_rs is not None: |
| 127 | + g_outs_map.append(last_rs) |
| 128 | + else: |
| 129 | + g_outs_map.append(g_out) |
| 130 | + |
| 131 | + epi_g_ins = [n for n in g_outs_map if n is not None] |
| 132 | + epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))] |
| 133 | + |
| 134 | + with exclude_wait_from_fx_side_effectful(): |
| 135 | + main_g = _extract_graph_with_inputs_outputs( |
| 136 | + g, |
| 137 | + g_ins, |
| 138 | + epi_g_ins, |
| 139 | + epi_g_ins_descs, |
| 140 | + ignore_must_be_in_fw_bw=True, |
| 141 | + ) |
| 142 | + epi_g = _extract_graph_with_inputs_outputs( |
| 143 | + g, |
| 144 | + epi_g_ins, |
| 145 | + g_outs, |
| 146 | + g_outs_descs, |
| 147 | + ignore_must_be_in_fw_bw=True, |
| 148 | + ) |
| 149 | + |
| 150 | + return main_g, epi_g |
0 commit comments