Skip to content

Commit fe8b82e

Browse files
committed
Update
[ghstack-poisoned]
1 parent 8b6f11c commit fe8b82e

File tree

2 files changed

+276
-1
lines changed

2 files changed

+276
-1
lines changed

autoparallel/auto_bucketing.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77

8-
from .autobucketing_util import bucket_func, bucket_plan, bucket_utils
8+
from .autobucketing_util import bucket_func, bucket_plan, bucket_utils, reorder
99

1010

1111
class simplefsdp_autobucketing_config:
@@ -71,4 +71,21 @@ def simple_fsdp_autobucketing_reordering_pass(
7171
reduce_scatter_plan,
7272
bucketable_nodes,
7373
)
74+
75+
if configs.enable_reorder_ir:
76+
print("Reorder scheduler nodes with autobucketing algroithm")
77+
node_length = len(snodes)
78+
snodes = reorder.reorder_all_gather(
79+
snodes,
80+
bucketable_nodes,
81+
all_gather_before_last_wait=True
82+
)
83+
assert node_length == len(snodes), (
84+
f"Missed nodes in reordering all gather: expected {node_length}, but got {len(snodes)}"
85+
)
86+
snodes = reorder.reorder_reduce_scatter(snodes, bucketable_nodes)
87+
assert node_length == len(snodes), (
88+
f"Missed nodes in reordering reduce scatter: expected {node_length}, but got {len(snodes)}"
89+
)
90+
7491
return snodes
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
#
3+
# This source code is licensed under the BSD license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# mypy: ignore-errors
7+
from collections import defaultdict
8+
from enum import IntEnum
9+
from typing import Dict, List, Optional, Tuple
10+
11+
import torch
12+
from torch.utils._ordered_set import OrderedSet
13+
from torch._inductor import ir, scheduler
14+
from torch._inductor.utils import is_collective
15+
16+
from .bucket_utils import check_ir_node_bucketable
17+
18+
19+
class NodeType(IntEnum):
    """Classification tag for scheduler nodes used by the reordering passes.

    Values are only compared for equality / membership; the integer values
    themselves carry no ordering semantics.
    """

    ALL_GATHER = 0      # all_gather collective kernel
    COMPUTE = 1         # any non-collective node (the default)
    REDUCE_SCATTER = 2  # reduce_scatter collective kernel
    AG_WAIT = 3         # wait node consuming an all_gather
    RS_WAIT = 4         # wait node consuming a reduce_scatter
26+
27+
def compute_node_users(
28+
snodes: List["scheduler.BaseSchedulerNode"],
29+
) -> Tuple[
30+
Dict["scheduler.BaseSchedulerNode", OrderedSet["scheduler.BaseSchedulerNode"]],
31+
Dict["scheduler.BaseSchedulerNode", OrderedSet["scheduler.BaseSchedulerNode"]],
32+
]:
33+
"""
34+
Compute the inverse users and users of each node
35+
"""
36+
buf_to_snode: Dict[str, scheduler.BaseSchedulerNode] = {}
37+
for node in snodes:
38+
if isinstance(node, scheduler.FusedSchedulerNode):
39+
for x in node.snodes:
40+
for buf in x.get_outputs():
41+
buf_to_snode[buf.get_name()] = node
42+
43+
for buf in node.get_outputs():
44+
buf_to_snode[buf.get_name()] = node
45+
46+
inverse_users = {}
47+
keys = list(buf_to_snode.keys())
48+
for node in snodes:
49+
dep_list = []
50+
for dep in node.unmet_dependencies:
51+
if dep.name in keys:
52+
dep_list.append(buf_to_snode[dep.name])
53+
inverse_users.update({node: OrderedSet(dep_list)})
54+
55+
node_users: Dict[
56+
scheduler.BaseSchedulerNode, OrderedSet[scheduler.BaseSchedulerNode]
57+
] = defaultdict(OrderedSet)
58+
for node, node_inverse_users in inverse_users.items():
59+
for inverse_user in node_inverse_users:
60+
node_users[inverse_user].add(node)
61+
62+
return inverse_users, node_users
63+
64+
65+
def _get_ir_node_type(ir_node: "ir.Operation", bucketable_ir_nodes) -> NodeType:
    """
    Determine the type of a ir node.

    Classifies `ir_node` as an all-gather, reduce-scatter, one of their wait
    nodes, or plain COMPUTE. A collective type is only assigned when the node
    passes `check_ir_node_bucketable`; anything else falls through to COMPUTE.
    """
    if isinstance(ir_node, ir._WaitKernel):
        # Determine if the wait node is waiting for ALL_GATHER or REDUCE_SCATTER
        ir_op_overload = getattr(ir_node.inputs[0], "op_overload", None)
        if (
            ir_op_overload == torch.ops._c10d_functional.all_gather_into_tensor.default
            and check_ir_node_bucketable(ir_node.inputs[0], bucketable_ir_nodes)
        ):
            return NodeType.AG_WAIT
        elif (
            ir_op_overload == torch.ops._c10d_functional.reduce_scatter_tensor.default
            and check_ir_node_bucketable(ir_node.inputs[0], bucketable_ir_nodes)
        ):
            return NodeType.RS_WAIT
    if isinstance(ir_node, ir._CollectiveKernel):
        # Determine if the collective kernel is for ALL_GATHER or REDUCE_SCATTER
        ir_op_overload = getattr(ir_node, "op_overload", None)
        if is_collective(
            ir_node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
        ) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
            return NodeType.ALL_GATHER
        elif is_collective(
            ir_node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
        ) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
            return NodeType.REDUCE_SCATTER

    if isinstance(ir_node, ir.FallbackKernel):
        # Fallback kernels are matched by python_kernel_name string instead of
        # op_overload (presumably how bucketed collectives are emitted —
        # TODO(review): confirm against the bucketing pass).
        python_kernel_name = ir_node.python_kernel_name
        if (
            python_kernel_name == "torch.ops._c10d_functional.wait_tensor.default"
            and check_ir_node_bucketable(ir_node, bucketable_ir_nodes)
        ):
            # The collective being awaited may be the direct input or one
            # level deeper (input-of-input), so both shapes are checked.
            inputs_rs_kernel_name1 = (
                getattr(ir_node.inputs[0], "python_kernel_name", "")
                == "torch.ops._c10d_functional.reduce_scatter_tensor.default"
            )
            inputs_rs_kernel_name2 = (
                hasattr(ir_node.inputs[0], "inputs")
                and getattr(ir_node.inputs[0].inputs[0], "python_kernel_name", "")
                == "torch.ops._c10d_functional.reduce_scatter_tensor.default"
            )
            if inputs_rs_kernel_name1 or inputs_rs_kernel_name2:
                return NodeType.RS_WAIT

            # NOTE(review): the all-gather match here uses the `_out` variant
            # (all_gather_into_tensor_out), unlike the _CollectiveKernel branch
            # above — verify this matches what bucketing actually emits.
            inputs_ag_kernel_name1 = (
                getattr(ir_node.inputs[0], "python_kernel_name", "")
                == "torch.ops._c10d_functional.all_gather_into_tensor_out.default"
            )
            inputs_ag_kernel_name2 = (
                hasattr(ir_node.inputs[0], "inputs")
                and getattr(ir_node.inputs[0].inputs[0], "python_kernel_name", "")
                == "torch.ops._c10d_functional.all_gather_into_tensor_out.default"
            )
            if inputs_ag_kernel_name1 or inputs_ag_kernel_name2:
                return NodeType.AG_WAIT
        elif (
            python_kernel_name
            == "torch.ops._c10d_functional.reduce_scatter_tensor.default"
        ) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
            return NodeType.REDUCE_SCATTER
        elif (
            python_kernel_name
            == "torch.ops._c10d_functional.all_gather_into_tensor_out.default"
        ) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
            return NodeType.ALL_GATHER
    # Anything not recognized above is treated as compute.
    return NodeType.COMPUTE
135+
136+
def get_node_type(node: "scheduler.BaseSchedulerNode", bucketable_ir_nodes) -> NodeType:
    """
    Determine the NodeType of a node.

    Fused nodes are always compute. Grouped nodes (created by bucketing) take
    the highest-priority type found among their children, in the order
    AG_WAIT > RS_WAIT > ALL_GATHER > REDUCE_SCATTER, defaulting to COMPUTE.
    Any other node is classified by its underlying IR node.
    """
    if isinstance(node, scheduler.FusedSchedulerNode):
        # Only compute nodes are fused.
        return NodeType.COMPUTE

    if isinstance(node, scheduler.GroupedSchedulerNode):
        # [Only for bucketing]: newly created AG and RS are grouped as
        # GroupedSchedulerNode. Pick the first matching type by priority.
        child_types = {
            _get_ir_node_type(child.node, bucketable_ir_nodes)
            for child in node.snodes
        }
        for candidate in (
            NodeType.AG_WAIT,
            NodeType.RS_WAIT,
            NodeType.ALL_GATHER,
            NodeType.REDUCE_SCATTER,
        ):
            if candidate in child_types:
                return candidate
        return NodeType.COMPUTE

    return _get_ir_node_type(node.node, bucketable_ir_nodes)
160+
161+
def reorder_all_gather(
    snodes: List["scheduler.BaseSchedulerNode"],
    bucketable_ir_nodes: set[str],
    all_gather_before_last_wait: Optional[bool] = True,
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Reorder All Gather and Wait in the forward/backward pass;
    1. all_gather_before_last_wait set to True: all_gather_i is reordered before wait_i-1
    2. all_gather_before_last_wait set to False: all_gather_i is reordered after wait_i-1

    Args:
        snodes: scheduler nodes in their current order. The input list is NOT
            mutated (the previous implementation reversed it in place, leaking
            a reversed list back to the caller).
        bucketable_ir_nodes: names of IR nodes eligible for bucketing.
        all_gather_before_last_wait: hoist each all_gather before (True) or
            after (False) the preceding wait node.

    Returns:
        A new list containing the same nodes in the reordered schedule.
    """
    result_list: List[scheduler.BaseSchedulerNode] = []
    all_gather_list: List[scheduler.BaseSchedulerNode] = []
    node_to_type: Dict[scheduler.BaseSchedulerNode, NodeType] = {}
    inverse_users, node_users = compute_node_users(snodes)

    for node in snodes:
        node_to_type[node] = get_node_type(node, bucketable_ir_nodes)

    # Traverse a reversed copy so the caller's list is left untouched.
    rev_snodes = list(reversed(snodes))
    for idx, node in enumerate(rev_snodes):
        node_type = node_to_type[node]
        if node_type in [NodeType.REDUCE_SCATTER, NodeType.COMPUTE, NodeType.RS_WAIT]:
            # we do not reorder reduce scatter and compute node
            if node not in result_list and node not in all_gather_list:
                result_list.append(node)
        elif node_type == NodeType.ALL_GATHER:
            # gather i-th all gather node and its compute dependencies
            all_gather_list.append(node)
            compute_deps = [
                n for n in inverse_users[node] if node_to_type[n] == NodeType.COMPUTE
            ]
            all_gather_list.extend(compute_deps)
        elif node_type == NodeType.AG_WAIT:
            if not all_gather_before_last_wait and len(all_gather_list) > 0:
                # In reverse order, the node right after this wait must be the
                # matching all_gather.
                assert node_to_type[rev_snodes[idx + 1]] == NodeType.ALL_GATHER
                # move i-th all gather node and its dependencies after (i-1)-th wait node (bc this is a reverse list)
                result_list.extend(all_gather_list)
                all_gather_list = []

            result_list.append(node)

            if all_gather_before_last_wait and len(all_gather_list) > 0:
                assert node_to_type[rev_snodes[idx + 1]] == NodeType.ALL_GATHER
                # move i-th all gather node and its dependencies before (i-1)-th wait node (bc this is a reverse list)
                result_list.extend(all_gather_list)
                all_gather_list = []
    # Flush any all_gathers with no preceding wait (e.g. the first one).
    if len(all_gather_list) > 0:
        result_list.extend(all_gather_list)
    result_list.reverse()

    return result_list
213+
214+
215+
def reorder_reduce_scatter(
    snodes: List["scheduler.BaseSchedulerNode"],
    bucketable_ir_nodes: set[str],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Reorder Reduce Scatter and Wait in the backward pass:
    reorder wait_i_rs before reduce_scatter_i+1.

    Args:
        snodes: scheduler nodes in their current order (not mutated).
        bucketable_ir_nodes: names of IR nodes eligible for bucketing.

    Returns:
        `snodes` unchanged when no bucketable reduce_scatter is present,
        otherwise a new list with each RS wait moved just before the next
        reduce_scatter.
    """
    result_list: List[scheduler.BaseSchedulerNode] = []
    wait_list: List[scheduler.BaseSchedulerNode] = []
    # Classify each node exactly once (previously get_node_type ran twice per
    # node, and an unused compute_node_users() result was also built).
    node_to_type: Dict[scheduler.BaseSchedulerNode, NodeType] = {
        node: get_node_type(node, bucketable_ir_nodes) for node in snodes
    }

    # Nothing to reorder if there is no bucketable reduce_scatter.
    if NodeType.REDUCE_SCATTER not in node_to_type.values():
        return snodes

    for idx, node in enumerate(snodes):
        node_type = node_to_type[node]
        if node_type in [NodeType.ALL_GATHER, NodeType.COMPUTE, NodeType.AG_WAIT]:
            if node not in result_list and node not in wait_list:
                result_list.append(node)
        elif node_type == NodeType.RS_WAIT:
            # Each RS wait is expected to directly follow its reduce_scatter.
            assert node_to_type[snodes[idx - 1]] == NodeType.REDUCE_SCATTER
            # Hold the wait node until the next reduce_scatter is emitted.
            wait_list.append(node)
        elif node_type == NodeType.REDUCE_SCATTER:
            if len(wait_list) > 0:
                # move the i-th wait node before (i+1)-th reduce scatter node
                result_list.extend(wait_list)
                wait_list = []
            # add reduce scatter node
            result_list.append(node)

    # Append any waits left over after the last reduce_scatter.
    if len(wait_list) > 0:
        result_list.extend(wait_list)
    return result_list

0 commit comments

Comments
 (0)