147 changes: 147 additions & 0 deletions autoparallel/pipeline/passes.py
@@ -0,0 +1,147 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
from contextlib import contextmanager
from functools import partial
from typing import Any

import torch
import torch.fx.node
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
from torch._inductor.fx_passes.bucketing import (
is_all_gather_into_tensor,
is_reduce_scatter_tensor,
)


@contextmanager
def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
    """Temporarily remove ``exclude_vals`` from FX's set of side-effectful
    functions, restoring the original set on exit."""
original_val = torch.fx.node._side_effectful_functions.copy()
try:
torch.fx.node._side_effectful_functions -= exclude_vals
yield
finally:
torch.fx.node._side_effectful_functions.clear()
torch.fx.node._side_effectful_functions.update(original_val)


exclude_wait_from_fx_side_effectful = partial(
exclude_from_fx_side_effectful,
{
torch.ops._c10d_functional.wait_tensor,
torch.ops._c10d_functional.wait_tensor.default,
},
)


@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
pass


@dataclasses.dataclass(frozen=True)
class EpilogueInput(AOTOutput):
pass


def split_fsdp_prefetch(
    g: torch.fx.Graph, stop_at_all_gather: bool = True
) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    """Split ``g`` into a prefetch graph and a main graph.

    For each placeholder, walk its chain of single-user, single-input nodes.
    With ``stop_at_all_gather=True`` a chain is only moved into the prefetch
    graph if it reaches an all-gather, in which case the all-gather and its
    wait are included; with ``stop_at_all_gather=False`` every such chain is
    moved wholesale. The prefetch graph's outputs become the main graph's
    inputs.
    """
g_ins = g.find_nodes(op="placeholder")
prefetch_g_outs_map = []

for g_in in g_ins:
Review comment (Contributor):

One thing I found a bit confusing when running this locally is that we are currently moving "views of inputs" into the prefetch subgraph. In my local example:

full backward graph: P2013183321
prefetch subgraph: P2013183365
remaining subgraph: P2013183449

you can see that tangents_1 is not an input in the "remaining subgraph", which is surprising because we should not be performing any FSDP collectives directly on tangents_1. It looks like this is because there is a view(tangents_1) in the main backward that we end up moving into the prefetch subgraph (technically harmless, but confusing).

n = g_in
has_ag = False
while True:
if len(n.users) != 1:
break
Review comment on lines +62 to +63 (Contributor):

This assumes that the partitioner has been changed so that we don't recompute the all-gather collectives in the backward pass.

@sanketpurandare you'll need to keep this in mind for your PR.

user = next(iter(n.users))
if len(user.all_input_nodes) > 1:
break
n = user
            if stop_at_all_gather and is_all_gather_into_tensor(n):
                has_ag = True
                # Include the wait on the all-gather result in the prefetched chain.
                w_n = next(iter(n.users))
                n = w_n
                break
        # Only move the chain into the prefetch graph if it actually reached an
        # all-gather; otherwise the placeholder (and any views of it) stays in
        # the main graph.
        if stop_at_all_gather and not has_ag:
            prefetch_g_outs_map.append(g_in)
        else:
            prefetch_g_outs_map.append(n)

prefetch_g_outs = prefetch_g_outs_map
prefetch_g_outs_descs: list[AOTOutput] = [
PrefetchOutput() for _ in range(len(prefetch_g_outs))
]
g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
g_outs_descs = pytree.arg_tree_leaves(
next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
)
with exclude_wait_from_fx_side_effectful():
prefetch_g = _extract_graph_with_inputs_outputs(
g,
g_ins,
prefetch_g_outs,
prefetch_g_outs_descs,
)

main_g = _extract_graph_with_inputs_outputs(
g,
prefetch_g_outs,
g_outs,
g_outs_descs,
)
return prefetch_g, main_g
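
A minimal usage sketch of this pass, assuming `gm` is a traced `torch.fx.GraphModule` whose graph contains the functional-collective FSDP all-gathers and `inputs` are matching example tensors (both obtained e.g. as in the test below; all names here are illustrative). It also shows the contract the review comments above refer to: the prefetch graph keeps the original placeholders and returns the chosen split points, which the main graph then consumes as its own placeholders.

from torch.fx import GraphModule

prefetch_g, main_g = split_fsdp_prefetch(gm.graph)
gm_prefetch = GraphModule(gm, prefetch_g)  # issues the all-gathers and waits on them
gm_main = GraphModule(gm, main_g)          # runs the remaining computation

# Running the two pieces back to back is expected to be equivalent to
# running the original module:
out = gm_main(*gm_prefetch(*inputs))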


def split_fsdp_reduce_scatters_epilogue(
    g: torch.fx.Graph,
) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    """Split ``g`` into a main graph and a reduce-scatter epilogue.

    For each graph output, walk backwards through single-input, single-user
    nodes; if the walk crosses a reduce-scatter, everything from the
    reduce-scatter to that output is moved into the epilogue graph, whose
    inputs are the main graph's outputs.
    """
g_ins = g.find_nodes(op="placeholder")
g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
g_outs_descs = pytree.arg_tree_leaves(
next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
)

g_outs_map = []
for g_out in g_outs:
n = g_out
has_rs = False
while n is not None:
if len(n.all_input_nodes) != 1:
break
n_in = n.all_input_nodes[0]
if len(n_in.users) > 1:
break
prev_n = n
n = n_in
if is_reduce_scatter_tensor(prev_n):
has_rs = True
break
if has_rs:
g_outs_map.append(n)
else:
g_outs_map.append(g_out)

epi_g_ins = [n for n in g_outs_map if n is not None]
epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]
main_g = _extract_graph_with_inputs_outputs(
g,
g_ins,
epi_g_ins,
epi_g_ins_descs,
)
epi_g = _extract_graph_with_inputs_outputs(
g,
epi_g_ins,
g_outs,
g_outs_descs,
)

return main_g, epi_g
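
Similarly, a sketch of the epilogue pass on its own, under the same assumptions about `gm` and `inputs` as above: the main graph stops right before the reduce-scatters, and the epilogue graph finishes them.

from torch.fx import GraphModule

main_g, epi_g = split_fsdp_reduce_scatters_epilogue(gm.graph)
gm_main = GraphModule(gm, main_g)  # everything up to the reduce-scatter inputs
gm_epi = GraphModule(gm, epi_g)    # the reduce-scatters and their waits

# Composition is again expected to match the original module:
out = gm_epi(*gm_main(*inputs))
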
100 changes: 100 additions & 0 deletions tests/test_pipeline_passes.py
@@ -0,0 +1,100 @@
from unittest.mock import patch

import pytest
import torch
from torch import nn
from torch.fx import GraphModule
from torch.testing._internal.distributed.fake_pg import FakeStore

from autoparallel.api import AutoParallel
from autoparallel.pipeline.passes import (
split_fsdp_prefetch,
split_fsdp_reduce_scatters_epilogue,
)


@pytest.fixture(scope="module", autouse=True)
def init_pg():
world_size = 256
fake_store = FakeStore()
if torch.distributed.is_initialized():
return
torch.distributed.init_process_group(
"fake", store=fake_store, rank=0, world_size=world_size
)


@pytest.fixture(scope="module")
def device_mesh_2d():
world_size = torch.distributed.get_world_size()
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda",
(world_size // 8, 8),
mesh_dim_names=(
"dp",
"tp",
),
)
return mesh


class FFN(nn.Module):
def __init__(self, dim1, dim2):
super().__init__()
bias = False
self.linear1 = nn.Linear(dim1, dim2, bias=bias)
self.linear2 = nn.Linear(dim2, dim1, bias=bias)

def forward(self, x, y):
return y + 2, self.linear2(self.linear1(x)), y + 2


def _make_model_and_input_fn(mesh, device="cuda"):
bs = 2048 * mesh.shape[0]
dim1 = 1024
dim2 = 4096

def model_fn():
return FFN(dim1, dim2)

def input_fn():
return torch.randn(bs, dim1).to(device), torch.randn(bs, 1).to(device)

return model_fn, input_fn


@patch("torch.cuda.device_count", lambda: 8)
@patch("torch.cuda.get_device_name", lambda device: "H100")
def test_fsdp_split_passes(device_mesh_2d):
low_mem = 0
high_mem = None
model_fn, input_fn = _make_model_and_input_fn(device_mesh_2d)
with torch.device("meta"):
model = model_fn()

with AutoParallel(model, input_fn, device_mesh_2d) as autop:
autop.add_parameter_memory_constraint(low=low_mem, high=high_mem)
sharding_placement = autop.optimize_placement()
autop.apply_placement(sharding_placement)
gm = autop.parallel_gm
g = gm.graph

def gen_g_inputs(g):
phs = g.find_nodes(op="placeholder")
ret = []
for ph in phs:
ft = ph.meta["val"]
t = torch.randn(ft.shape, dtype=ft.dtype, device=ft.device)
ret.append(t)
return ret

inputs = gen_g_inputs(g)
g_pro, g_main = split_fsdp_prefetch(g)
g_main, g_epi = split_fsdp_reduce_scatters_epilogue(g_main)

gm_pro = GraphModule(gm, g_pro)
gm_main = GraphModule(gm, g_main)
gm_epi = GraphModule(gm, g_epi)

gm(*inputs)
gm_epi(*gm_main(*gm_pro(*inputs)))