-
Notifications
You must be signed in to change notification settings - Fork 8
Pass to split all_gather prologue and reduce_scatter prologue from fsdp graph #201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
IvanKobzarev
wants to merge
1
commit into
main
Choose a base branch
from
IvanKobzarev/stack/9
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,147 @@ | ||
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import dataclasses | ||
| from contextlib import contextmanager | ||
| from functools import partial | ||
| from typing import Any | ||
|
|
||
| import torch | ||
| import torch.fx.node | ||
| import torch.utils._pytree as pytree | ||
| from torch._functorch._aot_autograd.descriptors import AOTOutput | ||
| from torch._functorch.partitioners import _extract_graph_with_inputs_outputs | ||
| from torch._inductor.fx_passes.bucketing import ( | ||
| is_all_gather_into_tensor, | ||
| is_reduce_scatter_tensor, | ||
| ) | ||
|
|
||
|
|
||
| @contextmanager | ||
| def exclude_from_fx_side_effectful(exclude_vals: set[Any]): | ||
| original_val = torch.fx.node._side_effectful_functions.copy() | ||
| try: | ||
| torch.fx.node._side_effectful_functions -= exclude_vals | ||
| yield | ||
| finally: | ||
| torch.fx.node._side_effectful_functions.clear() | ||
| torch.fx.node._side_effectful_functions.update(original_val) | ||
|
|
||
|
|
||
| exclude_wait_from_fx_side_effectful = partial( | ||
| exclude_from_fx_side_effectful, | ||
| { | ||
| torch.ops._c10d_functional.wait_tensor, | ||
| torch.ops._c10d_functional.wait_tensor.default, | ||
| }, | ||
| ) | ||
|
|
||
|
|
||
| @dataclasses.dataclass(frozen=True) | ||
| class PrefetchOutput(AOTOutput): | ||
| pass | ||
|
|
||
|
|
||
| @dataclasses.dataclass(frozen=True) | ||
| class EpilogueInput(AOTOutput): | ||
| pass | ||
|
|
||
|
|
||
| def split_fsdp_prefetch( | ||
| g: torch.fx.Graph, stop_at_all_gather: bool = True | ||
| ) -> tuple[torch.fx.Graph, torch.fx.Graph]: | ||
| g_ins = g.find_nodes(op="placeholder") | ||
| prefetch_g_outs_map = [] | ||
|
|
||
| for g_in in g_ins: | ||
| n = g_in | ||
| has_ag = False | ||
| while True: | ||
| if len(n.users) != 1: | ||
| break | ||
|
Comment on lines
+62
to
+63
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This assumes that the partitioner has been changed to that we don't recompute the all-gather collectives in the backward pass. @sanketpurandare you'll need to keep this in mind for your PR |
||
| user = next(iter(n.users)) | ||
| if len(user.all_input_nodes) > 1: | ||
| break | ||
| n = user | ||
| if stop_at_all_gather and is_all_gather_into_tensor(n): | ||
| has_ag = True | ||
| w_n = next(iter(n.users)) | ||
| n = w_n | ||
| break | ||
| if stop_at_all_gather and not has_ag: | ||
| prefetch_g_outs_map.append(g_in) | ||
| else: | ||
| prefetch_g_outs_map.append(n) | ||
|
|
||
| prefetch_g_outs = prefetch_g_outs_map | ||
| prefetch_g_outs_descs: list[AOTOutput] = [ | ||
| PrefetchOutput() for _ in range(len(prefetch_g_outs)) | ||
| ] | ||
| g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) | ||
| g_outs_descs = pytree.arg_tree_leaves( | ||
| next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs)) | ||
| ) | ||
| with exclude_wait_from_fx_side_effectful(): | ||
| prefetch_g = _extract_graph_with_inputs_outputs( | ||
| g, | ||
| g_ins, | ||
| prefetch_g_outs, | ||
| prefetch_g_outs_descs, | ||
| ) | ||
|
|
||
| main_g = _extract_graph_with_inputs_outputs( | ||
| g, | ||
| prefetch_g_outs, | ||
| g_outs, | ||
| g_outs_descs, | ||
| ) | ||
| return prefetch_g, main_g | ||
|
|
||
|
|
||
| def split_fsdp_reduce_scatters_epilogue( | ||
| g: torch.fx.Graph, | ||
| ) -> tuple[torch.fx.Graph, torch.fx.Graph]: | ||
| g_ins = g.find_nodes(op="placeholder") | ||
| g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) | ||
| g_outs_descs = pytree.arg_tree_leaves( | ||
| next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs)) | ||
| ) | ||
|
|
||
| g_outs_map = [] | ||
| for g_out in g_outs: | ||
| n = g_out | ||
| has_rs = False | ||
| while n is not None: | ||
| if len(n.all_input_nodes) != 1: | ||
| break | ||
| n_in = n.all_input_nodes[0] | ||
| if len(n_in.users) > 1: | ||
| break | ||
| prev_n = n | ||
| n = n_in | ||
| if is_reduce_scatter_tensor(prev_n): | ||
| has_rs = True | ||
| break | ||
| if has_rs: | ||
| g_outs_map.append(n) | ||
| else: | ||
| g_outs_map.append(g_out) | ||
|
|
||
| epi_g_ins = [n for n in g_outs_map if n is not None] | ||
| epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))] | ||
| main_g = _extract_graph_with_inputs_outputs( | ||
| g, | ||
| g_ins, | ||
| epi_g_ins, | ||
| epi_g_ins_descs, | ||
| ) | ||
| epi_g = _extract_graph_with_inputs_outputs( | ||
| g, | ||
| epi_g_ins, | ||
| g_outs, | ||
| g_outs_descs, | ||
| ) | ||
|
|
||
| return main_g, epi_g | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| from unittest.mock import patch | ||
|
|
||
| import pytest | ||
| import torch | ||
| from torch import nn | ||
| from torch.fx import GraphModule | ||
| from torch.testing._internal.distributed.fake_pg import FakeStore | ||
|
|
||
| from autoparallel.api import AutoParallel | ||
| from autoparallel.pipeline.passes import ( | ||
| split_fsdp_prefetch, | ||
| split_fsdp_reduce_scatters_epilogue, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module", autouse=True) | ||
| def init_pg(): | ||
| world_size = 256 | ||
| fake_store = FakeStore() | ||
| if torch.distributed.is_initialized(): | ||
| return | ||
| torch.distributed.init_process_group( | ||
| "fake", store=fake_store, rank=0, world_size=world_size | ||
| ) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def device_mesh_2d(): | ||
| world_size = torch.distributed.get_world_size() | ||
| mesh = torch.distributed.device_mesh.init_device_mesh( | ||
| "cuda", | ||
| (world_size // 8, 8), | ||
| mesh_dim_names=( | ||
| "dp", | ||
| "tp", | ||
| ), | ||
| ) | ||
| return mesh | ||
|
|
||
|
|
||
| class FFN(nn.Module): | ||
| def __init__(self, dim1, dim2): | ||
| super().__init__() | ||
| bias = False | ||
| self.linear1 = nn.Linear(dim1, dim2, bias=bias) | ||
| self.linear2 = nn.Linear(dim2, dim1, bias=bias) | ||
|
|
||
| def forward(self, x, y): | ||
| return y + 2, self.linear2(self.linear1(x)), y + 2 | ||
|
|
||
|
|
||
| def _make_model_and_input_fn(mesh, device="cuda"): | ||
| bs = 2048 * mesh.shape[0] | ||
| dim1 = 1024 | ||
| dim2 = 4096 | ||
|
|
||
| def model_fn(): | ||
| return FFN(dim1, dim2) | ||
|
|
||
| def input_fn(): | ||
| return torch.randn(bs, dim1).to(device), torch.randn(bs, 1).to(device) | ||
|
|
||
| return model_fn, input_fn | ||
|
|
||
|
|
||
| @patch("torch.cuda.device_count", lambda: 8) | ||
| @patch("torch.cuda.get_device_name", lambda device: "H100") | ||
| def test_fsdp_split_passes(device_mesh_2d): | ||
| low_mem = 0 | ||
| high_mem = None | ||
| model_fn, input_fn = _make_model_and_input_fn(device_mesh_2d) | ||
| with torch.device("meta"): | ||
| model = model_fn() | ||
|
|
||
| with AutoParallel(model, input_fn, device_mesh_2d) as autop: | ||
| autop.add_parameter_memory_constraint(low=low_mem, high=high_mem) | ||
| sharding_placement = autop.optimize_placement() | ||
| autop.apply_placement(sharding_placement) | ||
| gm = autop.parallel_gm | ||
| g = gm.graph | ||
|
|
||
| def gen_g_inputs(g): | ||
| phs = g.find_nodes(op="placeholder") | ||
| ret = [] | ||
| for ph in phs: | ||
| ft = ph.meta["val"] | ||
| t = torch.randn(ft.shape, dtype=ft.dtype, device=ft.device) | ||
| ret.append(t) | ||
| return ret | ||
|
|
||
| inputs = gen_g_inputs(g) | ||
| g_pro, g_main = split_fsdp_prefetch(g) | ||
| g_main, g_epi = split_fsdp_reduce_scatters_epilogue(g_main) | ||
|
|
||
| gm_pro = GraphModule(gm, g_pro) | ||
| gm_main = GraphModule(gm, g_main) | ||
| gm_epi = GraphModule(gm, g_epi) | ||
|
|
||
| gm(*inputs) | ||
| gm_epi(*gm_main(*gm_pro(*inputs))) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one thing I found a bit confusing when running this locally is that we are currently moving "views of inputs" into the prefetch subgraph. In my local example:
full backward graph: P2013183321
prefetch subgraph: P2013183365
remaining subgraph: P2013183449
you can see that
tangents_1is not an input in the "remaining subgraph", which is surprising because we should not be performing any FSDP collectives directly on tangents_1. And it looks like this is because there is aview(tangents_1)in the main backward, that we end up moving into the prefetch subgraph (technically harmless but confusing).