2 files changed, +59 -0 lines changed
New file (imported in the second file below as autoparallel.passes):

import dataclasses

import torch
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs


@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
    pass


def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    g_ins = g.find_nodes(op="placeholder")
    prefetch_g_outs_map = {}

    # For each placeholder, follow the chain of nodes that have exactly one
    # user and whose user takes only one input; the last node reached is that
    # input's prefetch output.
    for g_in in g_ins:
        n = g_in
        while True:
            if len(n.users) != 1:
                break
            user = next(iter(n.users))
            if len(user.all_input_nodes) > 1:
                break
            n = user
        prefetch_g_outs_map[g_in] = n

    prefetch_g_outs = list(prefetch_g_outs_map.values())
    prefetch_g_outs_descs: list[AOTOutput] = [
        PrefetchOutput() for _ in range(len(prefetch_g_outs))
    ]

    # Prefetch graph: original placeholders in, chain ends out.
    prefetch_g = _extract_graph_with_inputs_outputs(
        g,
        g_ins,
        prefetch_g_outs,
        prefetch_g_outs_descs,
    )

    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )
    # Main graph: chain ends in, original graph outputs out.
    main_g = _extract_graph_with_inputs_outputs(
        g,
        prefetch_g_outs,
        g_outs,
        g_outs_descs,
    )
    return prefetch_g, main_g
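A minimal usage sketch (not part of the diff): trace a toy function with torch.fx and run the pass on it. It assumes a recent PyTorch where torch.fx.Graph.find_nodes and the _functorch internals imported above are available; the function f and the tensor names are illustrative only.

import torch
import torch.fx

from autoparallel.passes import split_fsdp_prefetch


def f(x, w):
    # Each input feeds a single-user unary chain (the relu); both chains meet
    # at the matmul, so the relu nodes become the prefetch outputs.
    return torch.matmul(torch.relu(x), torch.relu(w))


gm = torch.fx.symbolic_trace(f)
prefetch_g, main_g = split_fsdp_prefetch(gm.graph)
print(prefetch_g)  # placeholders (x, w) -> relu nodes
print(main_g)      # relu results -> matmul -> output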
Second changed file:

@@ -253,6 +253,14 @@ def _pass(graph):
 print(f"Took {time.time() - t:.2f} s")
 parallel_mod = autop.apply_placement(sharding_placement)
 
+test_split_fsdp_prefetch = True
+if test_split_fsdp_prefetch:
+    gm = autop.parallel_gm
+    g = gm.graph
+    from autoparallel.passes import split_fsdp_prefetch
+
+    prefetch_g, main_g = split_fsdp_prefetch(g)
+
 # run weight init on our sharded DTensor params
 parallel_mod.to_empty(device="cuda")
 parallel_mod.init_weights()
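A possible follow-up, sketched here and not part of the diff: wrap the two subgraphs into torch.fx.GraphModules so they can be inspected (or run) separately, using the parallel GraphModule gm from above as the root so that any get_attr targets still resolve.

prefetch_mod = torch.fx.GraphModule(gm, prefetch_g)
main_mod = torch.fx.GraphModule(gm, main_g)
print(prefetch_mod.graph)  # per-input prefetch chains
print(main_mod.graph)      # remaining computation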