Commit 19ac5cb

Enabling split_dI_dW and split_fsdp_collectives passes

Author: Sanket Jayant Purandare
Parent: 9e86bcc

File tree: 8 files changed, +604 -56 lines


.github/workflows/test_cuda.yml

Lines changed: 1 addition & 1 deletion

@@ -46,4 +46,4 @@ jobs:
           python examples/example_dcp.py
           python examples/example_local_map.py
           python examples/example_ds3_local_map.py
-          python examples/example_pp_graph_partition.py
+          python examples/example_pp_graph_passes.py

autoparallel/_passes/graph_partition.py

Lines changed: 11 additions & 1 deletion

@@ -29,7 +29,15 @@ def partition_joint_with_descriptors(
     fw_compiler: Callable = boxed_nop_preserve_node_meta,
     bw_compiler: Callable = boxed_nop_preserve_node_meta,
 ) -> tuple[
-    torch.fx.GraphModule, torch.fx.GraphModule, int, int, int, int, list[int], list[Any]
+    torch.fx.GraphModule,
+    torch.fx.GraphModule,
+    int,
+    int,
+    int,
+    int,
+    int,
+    list[int],
+    list[Any],
 ]:
     aot_state: AOTState = jd._aot_state
     aot_graph_capture: AOTGraphCapture = jd._aot_graph_capture

@@ -79,9 +87,11 @@ def partition_joint_with_descriptors(
     num_mutate_inputs = len(
         [x for x in fw_metadata.input_info if x.mutates_data or x.mutates_metadata]
     )
+    num_params_buffers = aot_config.num_params_buffers
     return (
         fw_module,
         bw_module,
+        num_params_buffers,
         num_user_outputs,
         num_mutate_inputs,
         num_fw_outs_saved_for_bw,
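
Note: partition_joint_with_descriptors now also returns num_params_buffers, taken from aot_config. A minimal sketch of the updated call site, mirroring the unpacking done in autoparallel/api.py (the jd argument stands in for a previously built joint-with-descriptors object):

    from autoparallel._passes.graph_partition import partition_joint_with_descriptors

    (
        fw_module,             # forward torch.fx.GraphModule
        bw_module,             # backward torch.fx.GraphModule
        num_params_buffers,    # new: parameter + buffer count from aot_config
        num_user_outputs,
        num_mutate_inputs,
        num_fw_outs_saved_for_bw,
        num_symints_saved_for_bw,
        _indices_of_inps_to_detach,
        adjusted_flat_args,
    ) = partition_joint_with_descriptors(jd)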

autoparallel/_passes/split_di_dw_graph.py

Lines changed: 1 addition & 1 deletion

@@ -191,7 +191,7 @@ def _extract_fwd_bwd_modules(
 
 # TODO: in theory we can infer num_weight_gradients from the graph metadata directly
 def split_di_dw_graph(
-    bw_gm_old: fx.GraphModule, *, num_weight_gradients
+    bw_gm_old: fx.GraphModule, *, num_weight_gradients: int
 ) -> tuple[fx.GraphModule, fx.GraphModule, int]:
     # we could consider doing this is a non-mutating way
     bw_gm = copy.deepcopy(bw_gm_old)
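
With the annotation added, the expected usage is unchanged: the pass takes the full backward GraphModule plus the number of weight gradients and returns the dI/dW split. A minimal sketch matching the call site in autoparallel/api.py (bw_module and num_params_buffers come from the partitioning step):

    from autoparallel._passes.split_di_dw_graph import split_di_dw_graph

    # bw_dI_module computes input gradients, bw_dW_module computes weight gradients
    bw_dI_module, bw_dW_module, num_input_grads = split_di_dw_graph(
        bw_module, num_weight_gradients=num_params_buffers
    )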

autoparallel/_passes/split_fsdp_collectives.py

Lines changed: 41 additions & 24 deletions

@@ -5,6 +5,7 @@
 
 import dataclasses
 from contextlib import contextmanager
+from copy import deepcopy
 from functools import partial
 from typing import Any
 

@@ -49,12 +50,19 @@ class EpilogueInput(AOTOutput):
     pass
 
 
-def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
-    g_ins = g.find_nodes(op="placeholder")
+def split_fsdp_prefetch(
+    gm: torch.fx.GraphModule,
+    num_params: int,
+) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+    g = deepcopy(gm.graph)
+    all_g_ins = g.find_nodes(op="placeholder")
+    param_g_ins = all_g_ins[:num_params]
+    rem_g_ins = all_g_ins[num_params:]
+
     prefetch_g_outs_map = []
 
-    for g_in in g_ins:
-        n = g_in
+    for param_g_in in param_g_ins:
+        n = param_g_in
         last_ag = None
         while True:
             if len(n.users) != 1:

@@ -66,7 +74,7 @@ def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Gra
             if is_all_gather_into_tensor(n):
                 last_ag = n
         if last_ag is None:
-            prefetch_g_outs_map.append(g_in)
+            prefetch_g_outs_map.append(param_g_in)
         else:
             w_n = next(iter(last_ag.users))
             prefetch_g_outs_map.append(w_n)

@@ -82,34 +90,42 @@ def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Gra
     with exclude_wait_from_fx_side_effectful():
         prefetch_g = _extract_graph_with_inputs_outputs(
             g,
-            g_ins,
+            param_g_ins,
             prefetch_g_outs,
             prefetch_g_outs_descs,
             ignore_must_be_in_fw_bw=True,
         )
 
         main_g = _extract_graph_with_inputs_outputs(
             g,
-            prefetch_g_outs,
+            prefetch_g_outs + rem_g_ins,
             g_outs,
             g_outs_descs,
             ignore_must_be_in_fw_bw=True,
         )
-    return prefetch_g, main_g
+    prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g)
+    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
+    return prefetch_gm, main_gm
 
 
 def split_fsdp_reduce_scatters_epilogue(
-    g: torch.fx.Graph,
-) -> tuple[torch.fx.Graph, torch.fx.Graph]:
+    gm: torch.fx.GraphModule,
+    num_grads: int,
+) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+    g = deepcopy(gm.graph)
     g_ins = g.find_nodes(op="placeholder")
     g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
-    g_outs_descs = pytree.arg_tree_leaves(
-        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
+    grad_outs = g_outs[:num_grads]
+    rem_g_outs = g_outs[num_grads:]
+    out_descs = pytree.arg_tree_leaves(
+        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(grad_outs))
     )
+    grad_outs_descs = out_descs[:num_grads]
+    rem_g_outs_descs = out_descs[num_grads:]
 
-    g_outs_map = []
-    for g_out in g_outs:
-        n = g_out
+    grad_outs_map = []
+    for grad_out in grad_outs:
+        n = grad_out
         last_rs = None
         while n is not None:
             if len(n.all_input_nodes) != 1:

@@ -124,27 +140,28 @@ def split_fsdp_reduce_scatters_epilogue(
                 # The reduction of gradients happen in multiple steps
                 last_rs = n
         if last_rs is not None:
-            g_outs_map.append(last_rs)
+            grad_outs_map.append(last_rs)
         else:
-            g_outs_map.append(g_out)
+            grad_outs_map.append(grad_out)
 
-    epi_g_ins = [n for n in g_outs_map if n is not None]
+    epi_g_ins = grad_outs_map
     epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]
 
     with exclude_wait_from_fx_side_effectful():
         main_g = _extract_graph_with_inputs_outputs(
             g,
             g_ins,
-            epi_g_ins,
-            epi_g_ins_descs,
+            epi_g_ins + rem_g_outs,
+            epi_g_ins_descs + rem_g_outs_descs,
             ignore_must_be_in_fw_bw=True,
         )
         epi_g = _extract_graph_with_inputs_outputs(
             g,
             epi_g_ins,
-            g_outs,
-            g_outs_descs,
+            grad_outs,
+            grad_outs_descs,
             ignore_must_be_in_fw_bw=True,
         )
-
-    return main_g, epi_g
+    epi_gm = torch.fx._lazy_graph_module._make_graph_module(gm, epi_g)
+    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
+    return main_gm, epi_gm
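
Both passes now take a GraphModule plus a count and return GraphModules, so they can be applied directly to the partitioned forward and backward graphs. A minimal sketch, mirroring the api.py call sites below (fw_module, bw_module, and num_params are assumed to come from the surrounding apply_placement_pp flow):

    from autoparallel._passes.split_fsdp_collectives import (
        split_fsdp_prefetch,
        split_fsdp_reduce_scatters_epilogue,
    )

    # Hoist all_gathers feeding the first num_params placeholders into an unshard graph.
    unshard_module, fw_module = split_fsdp_prefetch(fw_module, num_params)

    # Peel reduce_scatters on the first num_params gradient outputs into an epilogue graph.
    bw_module, reduce_grad_module = split_fsdp_reduce_scatters_epilogue(
        bw_module, num_params
    )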

autoparallel/api.py

Lines changed: 79 additions & 13 deletions

@@ -588,23 +588,33 @@ def forward(self, *args):
 
 class AutoParallelPP(AutoParallel):
     def apply_placement_pp(
-        self, sharding_placement=None, generate_di_dw_split_graphs=False
+        self, sharding_placement=None, graph_passes: list[str] = []
     ) -> dict[str, Any]:
+        assert all(
+            g_pass in ["split_fsdp_collectives", "split_dI_dW"]
+            for g_pass in graph_passes
+        ), "Only split_fsdp_collectives and split_dI_dW_graph are supported"
         sharded_param_dict, sharded_buffer_dict = self._apply_placement_common(
             sharding_placement
         )
+        num_params = len(sharded_param_dict)
+        num_buffers = len(sharded_buffer_dict)
         (
             fw_module,
             bw_module,
+            num_params_buffers,
             num_user_outputs,
             num_mutate_inputs,
             num_fw_outs_saved_for_bw,
             num_symints_saved_for_bw,
             _indices_of_inps_to_detach,
             adjusted_flat_args,
         ) = partition_joint_with_descriptors(self.joint_with_descriptors)
-
+        assert num_params_buffers == (
+            num_params + num_buffers
+        ), f"num_params_buffers: {num_params_buffers}, num_params: {num_params}, num_buffers: {num_buffers}"
         print(
+            f"num_params_buffers: {num_params_buffers}\n"
             f"num_user_outputs: {num_user_outputs}\n"
            f"num_mutate_inputs: {num_mutate_inputs}\n"
            f"num_fw_outs_saved_for_bw: {num_fw_outs_saved_for_bw}\n"

@@ -631,14 +641,71 @@ def apply_placement_pp(
                 print_output=False, include_stride=True, include_device=True
             ),
         )
-        if generate_di_dw_split_graphs:
-            from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
+        unshard_module: Optional[torch.fx.GraphModule] = None
+        reduce_grad_module: Optional[torch.fx.GraphModule] = None
+        if "split_fsdp_collectives" in graph_passes:
+            assert (
+                not self.reshard_after_forward
+            ), "reshard_after_forward should be False to disable FSDP all_gather in the backward pass"
+            from autoparallel._passes.split_fsdp_collectives import (
+                split_fsdp_prefetch,
+                split_fsdp_reduce_scatters_epilogue,
+            )
 
-            num_weight_gradients = (
-                self.joint_with_descriptors._aot_state.aot_config.num_params_buffers
+            unshard_module, fw_module = split_fsdp_prefetch(fw_module, num_params)
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_unshard_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: unshard_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
             )
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_fwd_no_fsdp_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: fw_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+            bw_module, reduce_grad_module = split_fsdp_reduce_scatters_epilogue(
+                bw_module, num_params
+            )
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_bwd_no_fsdp_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: bw_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_reduce_grad_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: reduce_grad_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+
+        bw_dI_module: Optional[torch.fx.GraphModule] = None
+        bw_dW_module: Optional[torch.fx.GraphModule] = None
+        num_input_grads = 0
+        if "split_dI_dW" in graph_passes:
+            from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
+
             bw_dI_module, bw_dW_module, num_input_grads = split_di_dw_graph(
-                bw_module, num_weight_gradients=num_weight_gradients
+                bw_module,
+                num_weight_gradients=num_params_buffers,
             )
             trace_structured(
                 "artifact",

@@ -669,24 +736,23 @@ def apply_placement_pp(
                 raise RuntimeError(
                     "attempted to run split dI/dW pass on a graph that has no input gradients"
                 )
-        else:
-            bw_dI_module, bw_dW_module, num_input_grads = None, None, -1
 
         graph_meta: dict[str, int] = {
             "num_mutate_inputs": num_mutate_inputs,
             "num_user_outputs": num_user_outputs,
             "num_symints_saved_for_bw": num_symints_saved_for_bw,
-            "num_weight_buffer_grads": len(sharded_param_dict)
-            + len(sharded_buffer_dict),
+            "num_params": num_params,
+            "num_buffers": num_buffers,
             "num_input_grads": num_input_grads,
         }
+
         graph_modules: dict[str, Optional[torch.fx.GraphModule]] = {
             "fw": fw_module,
             "full_bw": bw_module,
             "bw_dI": bw_dI_module,
             "bw_dW": bw_dW_module,
-            "unshard": None,
-            "reduce_grad": None,
+            "unshard": unshard_module,
+            "reduce_grad": reduce_grad_module,
         }
         self.parallel_model = AutoParallelPPModule(
             sharded_param_dict,
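
The new graph_passes argument replaces the old generate_di_dw_split_graphs flag and is opt-in per pass. A hedged sketch of how a caller might request both passes (autop and sharding_placement are assumed to be an already-constructed AutoParallelPP instance and its chosen placement; reshard_after_forward must be False for the FSDP split, per the assert above):

    # Run both new passes; the method returns a dict (contents not shown in this diff).
    result = autop.apply_placement_pp(
        sharding_placement=sharding_placement,
        graph_passes=["split_fsdp_collectives", "split_dI_dW"],
    )
    # Internally, apply_placement_pp builds graph_modules with keys
    # "fw", "full_bw", "bw_dI", "bw_dW", "unshard", "reduce_grad",
    # alongside graph_meta counts such as num_params, num_buffers, num_input_grads.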
