2 changes: 1 addition & 1 deletion .github/workflows/test_cuda.yml
@@ -46,4 +46,4 @@ jobs:
python examples/example_dcp.py
python examples/example_local_map.py
python examples/example_ds3_local_map.py
python examples/example_pp_graph_partition.py
python examples/example_pp_graph_passes.py
12 changes: 11 additions & 1 deletion autoparallel/_passes/graph_partition.py
@@ -29,7 +29,15 @@ def partition_joint_with_descriptors(
fw_compiler: Callable = boxed_nop_preserve_node_meta,
bw_compiler: Callable = boxed_nop_preserve_node_meta,
) -> tuple[
torch.fx.GraphModule, torch.fx.GraphModule, int, int, int, int, list[int], list[Any]
torch.fx.GraphModule,
torch.fx.GraphModule,
int,
int,
int,
int,
int,
list[int],
list[Any],
]:
aot_state: AOTState = jd._aot_state
aot_graph_capture: AOTGraphCapture = jd._aot_graph_capture
@@ -79,9 +87,11 @@ def partition_joint_with_descriptors(
num_mutate_inputs = len(
[x for x in fw_metadata.input_info if x.mutates_data or x.mutates_metadata]
)
num_params_buffers = aot_config.num_params_buffers
return (
fw_module,
bw_module,
num_params_buffers,
num_user_outputs,
num_mutate_inputs,
num_fw_outs_saved_for_bw,
[diff truncated]
2 changes: 1 addition & 1 deletion autoparallel/_passes/split_di_dw_graph.py
@@ -191,7 +191,7 @@ def _extract_fwd_bwd_modules(

# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
def split_di_dw_graph(
bw_gm_old: fx.GraphModule, *, num_weight_gradients
bw_gm_old: fx.GraphModule, *, num_weight_gradients: int
) -> tuple[fx.GraphModule, fx.GraphModule, int]:
# we could consider doing this in a non-mutating way
bw_gm = copy.deepcopy(bw_gm_old)
[diff truncated]
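
For reference, a minimal usage sketch of the updated signature (the deepcopy behavior and the keyword name come from this diff; the caller-side variable names are assumptions):

# Hypothetical caller: bw_module is assumed to be the backward GraphModule produced by
# the partitioner, and num_params_buffers the count it returns alongside the graphs.
from autoparallel._passes.split_di_dw_graph import split_di_dw_graph

bw_dI_module, bw_dW_module, num_input_grads = split_di_dw_graph(
    bw_module,                                 # left intact; the pass deep-copies it
    num_weight_gradients=num_params_buffers,   # one gradient per parameter/buffer input
)
# bw_dI_module produces the input gradients (dI), bw_dW_module the weight gradients (dW),
# and num_input_grads reports how many input gradients the graph actually has.
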
65 changes: 41 additions & 24 deletions autoparallel/_passes/split_fsdp_collectives.py
@@ -5,6 +5,7 @@

import dataclasses
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
from typing import Any

@@ -49,12 +50,19 @@ class EpilogueInput(AOTOutput):
pass


def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
g_ins = g.find_nodes(op="placeholder")
def split_fsdp_prefetch(
gm: torch.fx.GraphModule,
num_params: int,

IvanKobzarev (Contributor) commented on Nov 5, 2025:
    nit: If we use export with descriptors, num_params could potentially be taken from metadata.
Author replied:
    Yeah, for now this is easily obtainable from the graph meta. Same as above, this can be removed later.

) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
g = deepcopy(gm.graph)

Contributor commented:
    Curious, why do you want to keep the original graph unchanged? Will it be used further?
Author replied:
    Since we are using its container graph module to initialize the two new graph modules, I thought it would be safer this way; we can remove this later.

all_g_ins = g.find_nodes(op="placeholder")
param_g_ins = all_g_ins[:num_params]
rem_g_ins = all_g_ins[num_params:]

prefetch_g_outs_map = []

for g_in in g_ins:
n = g_in
for param_g_in in param_g_ins:
n = param_g_in
last_ag = None
while True:
if len(n.users) != 1:
@@ -66,7 +74,7 @@ def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
if is_all_gather_into_tensor(n):
last_ag = n
if last_ag is None:
prefetch_g_outs_map.append(g_in)
prefetch_g_outs_map.append(param_g_in)
else:
w_n = next(iter(last_ag.users))
prefetch_g_outs_map.append(w_n)
@@ -82,34 +90,42 @@ def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
with exclude_wait_from_fx_side_effectful():
prefetch_g = _extract_graph_with_inputs_outputs(
g,
g_ins,
param_g_ins,
prefetch_g_outs,
prefetch_g_outs_descs,
ignore_must_be_in_fw_bw=True,
)

main_g = _extract_graph_with_inputs_outputs(
g,
prefetch_g_outs,
prefetch_g_outs + rem_g_ins,
g_outs,
g_outs_descs,
ignore_must_be_in_fw_bw=True,
)
return prefetch_g, main_g
prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g)
main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
return prefetch_gm, main_gm


def split_fsdp_reduce_scatters_epilogue(
g: torch.fx.Graph,
) -> tuple[torch.fx.Graph, torch.fx.Graph]:
gm: torch.fx.GraphModule,
num_grads: int,
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
g = deepcopy(gm.graph)
g_ins = g.find_nodes(op="placeholder")
g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
g_outs_descs = pytree.arg_tree_leaves(
next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
grad_outs = g_outs[:num_grads]
rem_g_outs = g_outs[num_grads:]
out_descs = pytree.arg_tree_leaves(
next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(grad_outs))
)
grad_outs_descs = out_descs[:num_grads]
rem_g_outs_descs = out_descs[num_grads:]

g_outs_map = []
for g_out in g_outs:
n = g_out
grad_outs_map = []
for grad_out in grad_outs:
n = grad_out
last_rs = None
while n is not None:
if len(n.all_input_nodes) != 1:
@@ -124,27 +140,28 @@ def split_fsdp_reduce_scatters_epilogue(
# The reduction of gradients happens in multiple steps
last_rs = n
if last_rs is not None:
g_outs_map.append(last_rs)
grad_outs_map.append(last_rs)
else:
g_outs_map.append(g_out)
grad_outs_map.append(grad_out)

epi_g_ins = [n for n in g_outs_map if n is not None]
epi_g_ins = grad_outs_map
epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]

with exclude_wait_from_fx_side_effectful():
main_g = _extract_graph_with_inputs_outputs(
g,
g_ins,
epi_g_ins,
epi_g_ins_descs,
epi_g_ins + rem_g_outs,
epi_g_ins_descs + rem_g_outs_descs,
ignore_must_be_in_fw_bw=True,
)
epi_g = _extract_graph_with_inputs_outputs(
g,
epi_g_ins,
g_outs,
g_outs_descs,
grad_outs,
grad_outs_descs,
ignore_must_be_in_fw_bw=True,
)

return main_g, epi_g
epi_gm = torch.fx._lazy_graph_module._make_graph_module(gm, epi_g)
main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
return main_gm, epi_gm
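
Taken together, a minimal sketch of how the two reworked passes might be invoked (pass names, argument order, and return order come from this diff; the caller-side variable names are assumptions):

# Hypothetical caller: fw_module/bw_module are assumed to be the partitioned forward and
# backward GraphModules, and num_params the number of parameter placeholders that lead
# the forward inputs.
from autoparallel._passes.split_fsdp_collectives import (
    split_fsdp_prefetch,
    split_fsdp_reduce_scatters_epilogue,
)

# Forward: hoist the FSDP all_gather chains on the parameter inputs into a
# standalone "unshard" graph that can be prefetched ahead of the main forward.
unshard_gm, main_fw_gm = split_fsdp_prefetch(fw_module, num_params)

# Backward: peel the trailing gradient reduce_scatters off into a "reduce_grad"
# epilogue that can be overlapped with other pipeline work.
main_bw_gm, reduce_grad_gm = split_fsdp_reduce_scatters_epilogue(bw_module, num_params)
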
92 changes: 79 additions & 13 deletions autoparallel/api.py
@@ -588,23 +588,33 @@ def forward(self, *args):

class AutoParallelPP(AutoParallel):
def apply_placement_pp(
self, sharding_placement=None, generate_di_dw_split_graphs=False
self, sharding_placement=None, graph_passes: list[str] = []
) -> dict[str, Any]:
assert all(
g_pass in ["split_fsdp_collectives", "split_dI_dW"]
for g_pass in graph_passes
), "Only split_fsdp_collectives and split_dI_dW_graph are supported"
sharded_param_dict, sharded_buffer_dict = self._apply_placement_common(
sharding_placement
)
num_params = len(sharded_param_dict)
num_buffers = len(sharded_buffer_dict)
(
fw_module,
bw_module,
num_params_buffers,
num_user_outputs,
num_mutate_inputs,
num_fw_outs_saved_for_bw,
num_symints_saved_for_bw,
_indices_of_inps_to_detach,
adjusted_flat_args,
) = partition_joint_with_descriptors(self.joint_with_descriptors)

assert num_params_buffers == (
num_params + num_buffers
), f"num_params_buffers: {num_params_buffers}, num_params: {num_params}, num_buffers: {num_buffers}"
print(
f"num_params_buffers: {num_params_buffers}\n"
f"num_user_outputs: {num_user_outputs}\n"
f"num_mutate_inputs: {num_mutate_inputs}\n"
f"num_fw_outs_saved_for_bw: {num_fw_outs_saved_for_bw}\n"
@@ -631,14 +641,71 @@ def apply_placement_pp(
print_output=False, include_stride=True, include_device=True
),
)
if generate_di_dw_split_graphs:
from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
unshard_module: Optional[torch.fx.GraphModule] = None
reduce_grad_module: Optional[torch.fx.GraphModule] = None
if "split_fsdp_collectives" in graph_passes:
assert (
not self.reshard_after_forward
), "reshard_after_forward should be False to disable FSDP all_gather in the backward pass"
from autoparallel._passes.split_fsdp_collectives import (
split_fsdp_prefetch,
split_fsdp_reduce_scatters_epilogue,
)

num_weight_gradients = (
self.joint_with_descriptors._aot_state.aot_config.num_params_buffers
unshard_module, fw_module = split_fsdp_prefetch(fw_module, num_params)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "autoparallel_pp_unshard_graph",
"encoding": "string",
},
payload_fn=lambda: unshard_module.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "autoparallel_pp_fwd_no_fsdp_graph",
"encoding": "string",
},
payload_fn=lambda: fw_module.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
bw_module, reduce_grad_module = split_fsdp_reduce_scatters_epilogue(
bw_module, num_params
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "autoparallel_pp_bwd_no_fsdp_graph",
"encoding": "string",
},
payload_fn=lambda: bw_module.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "autoparallel_pp_reduce_grad_graph",
"encoding": "string",
},
payload_fn=lambda: reduce_grad_module.print_readable(
print_output=False, include_stride=True, include_device=True
),
)

bw_dI_module: Optional[torch.fx.GraphModule] = None
bw_dW_module: Optional[torch.fx.GraphModule] = None
num_input_grads = 0
if "split_dI_dW" in graph_passes:
from autoparallel._passes.split_di_dw_graph import split_di_dw_graph

bw_dI_module, bw_dW_module, num_input_grads = split_di_dw_graph(
bw_module, num_weight_gradients=num_weight_gradients
bw_module,
num_weight_gradients=num_params_buffers,
)
trace_structured(
"artifact",
@@ -669,24 +736,23 @@ def apply_placement_pp(
raise RuntimeError(
"attempted to run split dI/dW pass on a graph that has no input gradients"
)
else:
bw_dI_module, bw_dW_module, num_input_grads = None, None, -1

graph_meta: dict[str, int] = {
"num_mutate_inputs": num_mutate_inputs,
"num_user_outputs": num_user_outputs,
"num_symints_saved_for_bw": num_symints_saved_for_bw,
"num_weight_buffer_grads": len(sharded_param_dict)
+ len(sharded_buffer_dict),
"num_params": num_params,
"num_buffers": num_buffers,
"num_input_grads": num_input_grads,
}

graph_modules: dict[str, Optional[torch.fx.GraphModule]] = {
"fw": fw_module,
"full_bw": bw_module,
"bw_dI": bw_dI_module,
"bw_dW": bw_dW_module,
"unshard": None,
"reduce_grad": None,
"unshard": unshard_module,
"reduce_grad": reduce_grad_module,
}
self.parallel_model = AutoParallelPPModule(
sharded_param_dict,
[diff truncated]
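
For orientation, a minimal end-to-end sketch of the new entry point (only the method name, the graph_passes values, and the reshard_after_forward requirement come from this diff; the instance name and the result keys are assumptions):

# Hypothetical usage: autop is assumed to be an AutoParallelPP instance built with
# reshard_after_forward=False, which the split_fsdp_collectives pass asserts on.
result = autop.apply_placement_pp(
    sharding_placement=sharding_placement,          # assumed to be precomputed
    graph_passes=["split_fsdp_collectives", "split_dI_dW"],
)

# The returned dict (keys assumed from the local variables in this diff) would expose
# the six graph modules and the per-graph metadata used by the PP runtime:
graph_modules = result["graph_modules"]   # fw, full_bw, bw_dI, bw_dW, unshard, reduce_grad
graph_meta = result["graph_meta"]         # num_params, num_buffers, num_input_grads, ...
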