Skip to content

Commit 017b8e2

Browse files
committed
Update
[ghstack-poisoned]
2 parents fe8b82e + b7014bd commit 017b8e2

File tree

7 files changed

+41
-32
lines changed

7 files changed

+41
-32
lines changed

autoparallel/auto_bucketing.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,14 @@ def simple_fsdp_autobucketing_reordering_pass(
7676
print("Reorder scheduler nodes with autobucketing algroithm")
7777
node_length = len(snodes)
7878
snodes = reorder.reorder_all_gather(
79-
snodes,
80-
bucketable_nodes,
81-
all_gather_before_last_wait=True
82-
)
83-
assert node_length == len(snodes), (
84-
f"Missed nodes in reordering all gather: expected {node_length}, but got {len(snodes)}"
79+
snodes, bucketable_nodes, all_gather_before_last_wait=False
8580
)
81+
assert node_length == len(
82+
snodes
83+
), f"Missed nodes in reordering all gather: expected {node_length}, but got {len(snodes)}"
8684
snodes = reorder.reorder_reduce_scatter(snodes, bucketable_nodes)
87-
assert node_length == len(snodes), (
88-
f"Missed nodes in reordering reduce scatter: expected {node_length}, but got {len(snodes)}"
89-
)
85+
assert node_length == len(
86+
snodes
87+
), f"Missed nodes in reordering reduce scatter: expected {node_length}, but got {len(snodes)}"
9088

9189
return snodes

autoparallel/autobucketing_util/bucket_func.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111
from torch._inductor import ir, scheduler
12-
from torch._inductor.comms import bucket_all_gathers, bucket_reduce_scatters, get_op_idx
12+
from torch._inductor.comms import get_op_idx
1313
from torch._inductor.dependencies import StarDep, WeakDep
1414
from torch._inductor.utils import is_collective, is_wait
1515
from torch._inductor.virtualized import V
@@ -23,6 +23,8 @@
2323
_replace_scheduler_buffer,
2424
_schedule_fallback_operation,
2525
_schedule_snode,
26+
bucket_all_gathers,
27+
bucket_reduce_scatters,
2628
check_ir_node_bucketable,
2729
)
2830

autoparallel/autobucketing_util/bucket_plan.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,9 @@ def get_simplefsdp_auto_plan(
108108
comp_time_dict,
109109
memory_dict,
110110
peak_memory_per_step_dict,
111-
) = benchmark_and_sync_runtime(sched, snodes, bucketable_nodes)
111+
) = benchmark_and_sync_runtime(
112+
sched, snodes, name_to_buf, name_to_fused_node, bucketable_nodes, configs
113+
)
112114
future_comp_time = sum(comp_time_dict.values())
113115
peak_memory = max(peak_memory_per_step_dict.values()) + configs.peak_memory_offset
114116

@@ -139,7 +141,7 @@ def get_simplefsdp_auto_plan(
139141
current_ag_bucket,
140142
schedule_fallback_operation,
141143
name_to_buf,
142-
torch.ops._c10d_functional.all_gather_into_tensor.default,
144+
"torch.ops._c10d_functional.all_gather_into_tensor.default",
143145
comm_cache,
144146
)
145147

@@ -243,7 +245,7 @@ def get_simplefsdp_auto_plan(
243245
current_rs_bucket,
244246
schedule_fallback_operation,
245247
name_to_buf,
246-
torch.ops._c10d_functional.reduce_scatter_tensor.default,
248+
"torch.ops._c10d_functional.reduce_scatter_tensor.default",
247249
comm_cache,
248250
ReduceOp.AVG,
249251
)
@@ -290,7 +292,7 @@ def get_simplefsdp_auto_plan(
290292
current_rs_bucket,
291293
schedule_fallback_operation,
292294
name_to_buf,
293-
torch.ops._c10d_functional.reduce_scatter_tensor.default,
295+
"torch.ops._c10d_functional.reduce_scatter_tensor.default",
294296
comm_cache,
295297
ReduceOp.AVG,
296298
)
@@ -333,10 +335,10 @@ def get_simplefsdp_auto_plan(
333335
]
334336
seen_new_bucketable_ag = False
335337

336-
if len(current_ag_bucket) > 0 or len(all_gather_plan) == 0:
338+
if len(current_ag_bucket) > 0:
337339
all_gather_plan.append(current_ag_bucket)
338340

339-
if len(current_rs_bucket) > 0 or len(reduce_scatter_plan) == 0:
341+
if len(current_rs_bucket) > 0:
340342
reduce_scatter_plan.append(current_rs_bucket)
341343

342344
return all_gather_plan, reduce_scatter_plan

autoparallel/autobucketing_util/bucket_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def check_ir_node_bucketable(
163163

164164
def _get_fx_node(
165165
snode_or_ir_node: Union["scheduler.BaseSchedulerNode", "ir.IRNode"],
166-
expected_op: Callable[[Any]],
166+
expected_op: Any,
167167
) -> torch.fx.Node:
168168
origins = None
169169
if isinstance(snode_or_ir_node, scheduler.BaseSchedulerNode):
@@ -190,7 +190,7 @@ def _get_fx_node(
190190

191191
def get_snode_process_group_info(
192192
snode: "scheduler.BaseSchedulerNode",
193-
expected_op: Callable[[Any]],
193+
expected_op: Any,
194194
resolve_pg: bool = False,
195195
) -> tuple[int, Union[str, ProcessGroup]]:
196196
fx_node = _get_fx_node(snode, expected_op=expected_op)
@@ -248,7 +248,7 @@ def get_snode_tensor_info(
248248

249249
def _estimate_bucketed_node_list(
250250
current_node_list: list["scheduler.BaseSchedulerNode"],
251-
schedule_fallback_operation: Callable[[Any]],
251+
schedule_fallback_operation: Callable[[Any], Any],
252252
group_size: int,
253253
group_name: str,
254254
name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
@@ -272,7 +272,7 @@ def _estimate_bucketed_node_list(
272272
)
273273
return estimated_comm, comm_size_inp, comm_size_out
274274

275-
if comm_func == torch.ops._c10d_functional.all_gather_into_tensor.default:
275+
if comm_func == "torch.ops._c10d_functional.all_gather_into_tensor.default":
276276
bucked_node = bucket_all_gathers(
277277
schedule_fallback_operation,
278278
group_size,
@@ -284,7 +284,7 @@ def _estimate_bucketed_node_list(
284284
)
285285
comm_size_inp = bucked_node[0].layout.size
286286
comm_size_out = bucked_node[1].layout.size
287-
elif comm_func == torch.ops._c10d_functional.reduce_scatter_tensor.default:
287+
elif comm_func == "torch.ops._c10d_functional.reduce_scatter_tensor.default":
288288
bucked_node = bucket_reduce_scatters(
289289
schedule_fallback_operation,
290290
group_size,
@@ -311,7 +311,7 @@ def _estimate_bucketed_node_list(
311311

312312
def estimate_bucketed_snode_runtime(
313313
node_bucket_dict: Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]],
314-
schedule_fallback_operation: Callable[[Any]],
314+
schedule_fallback_operation: Callable[[Any], Any],
315315
name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
316316
comm_func: Callable[[Any], Any],
317317
comm_cache: Dict[Any, Any],
@@ -687,7 +687,7 @@ def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Si
687687

688688
reduce_scatter_tensor = schedule_fallback_operation(
689689
torch.ops._c10d_functional.reduce_scatter_tensor.default,
690-
(reduce_scatter_input, reduce_op, group_size, group_name),
690+
(reduce_scatter_input, str(reduce_op), group_size, group_name),
691691
{},
692692
dep_operations=chunk_cat,
693693
)

autoparallel/autobucketing_util/estimation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import pickle
99
from collections import defaultdict
10+
from typing import Any
1011

1112
import torch
1213
import torch.distributed as c10d
@@ -15,7 +16,6 @@
1516
from torch._inductor.virtualized import V
1617
from torch.utils._ordered_set import OrderedSet
1718

18-
from ..auto_bucketing import simplefsdp_autobucketing_config
1919
from .bucket_utils import (
2020
check_ir_node_bucketable,
2121
get_snode_process_group_info,
@@ -44,7 +44,7 @@ def benchmark_and_sync_runtime(
4444
name_to_buf: dict[str, "scheduler.SchedulerBuffer"],
4545
name_to_fused_node: dict[str, "scheduler.BaseSchedulerNode"],
4646
bucketable_nodes: set[str],
47-
configs: "simplefsdp_autobucketing_config",
47+
configs: Any,
4848
):
4949
world_size = c10d.distributed_c10d.get_world_size()
5050

@@ -220,6 +220,6 @@ def benchmark_and_sync_runtime(
220220
median_runtimes = sync_dict_across_ranks(comm_cache.cache, world_size)
221221
comm_cache.cache = median_runtimes
222222
comm_cache._update_max_size()
223-
with open(configs.simplefsdp.save_estimation_path, "wb") as file:
223+
with open(configs.save_estimation_path, "wb") as file:
224224
pickle.dump(comm_cache.cache, file)
225225
return comm_cache, comp_time_dict, memory_dict, peak_memory_per_step_dict

autoparallel/autobucketing_util/estimation_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def add_comm_time(self, tensor_input_size, tensor_output_size, comm_func, value)
200200
def get_comm_time(
201201
self, tensor_input_size, tensor_output_size, comm_func, calibrated=False
202202
):
203+
comm_func = str(comm_func)
203204
key = (tuple(tensor_input_size), tuple(tensor_output_size), comm_func)
204205
if key in self.cache:
205206
return self.cache[key]
@@ -368,7 +369,7 @@ def to_real_tensor(e: Any) -> Any:
368369
return out
369370

370371
def delete_tensor_in_list(tensor_list: list[Any]) -> None:
371-
for i in range(len(tensor_list)):
372+
for i in range(len(tensor_list) - 1, -1, -1):
372373
if isinstance(tensor_list[i], torch.Tensor):
373374
tensor_list[i].cpu()
374375
del tensor_list[i]

autoparallel/autobucketing_util/reorder.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from typing import Dict, List, Optional, Tuple
1010

1111
import torch
12-
from torch.utils._ordered_set import OrderedSet
1312
from torch._inductor import ir, scheduler
14-
from torch._inductor.utils import is_collective
13+
from torch._inductor.utils import contains_collective, contains_wait, is_collective
14+
from torch.utils._ordered_set import OrderedSet
1515

1616
from .bucket_utils import check_ir_node_bucketable
1717

@@ -143,7 +143,9 @@ def get_node_type(node: "scheduler.BaseSchedulerNode", bucketable_ir_nodes) -> N
143143

144144
if isinstance(node, scheduler.GroupedSchedulerNode):
145145
# [Only for bucketing]: newly created AG and RS are grouped as GroupedSchedulerNode
146-
child_nodes_type = [_get_ir_node_type(n.node, bucketable_ir_nodes) for n in node.snodes]
146+
child_nodes_type = [
147+
_get_ir_node_type(n.node, bucketable_ir_nodes) for n in node.snodes
148+
]
147149
if NodeType.AG_WAIT in child_nodes_type:
148150
return NodeType.AG_WAIT
149151
elif NodeType.RS_WAIT in child_nodes_type:
@@ -187,7 +189,11 @@ def reorder_all_gather(
187189
all_gather_list.append(node)
188190
inverse_user = list(inverse_users[node])
189191
inverse_user = [
190-
n for n in inverse_user if node_to_type[n] == NodeType.COMPUTE
192+
n
193+
for n in inverse_user
194+
if node_to_type[n] == NodeType.COMPUTE
195+
and not contains_collective(n)
196+
and not contains_wait(n)
191197
]
192198
if len(inverse_user) > 0:
193199
all_gather_list.extend(inverse_user)
@@ -244,7 +250,7 @@ def reorder_reduce_scatter(
244250
wait_list.append(node)
245251
node_user = node_users[node]
246252
node_user = [n for n in node_user if node_to_type[n] == NodeType.COMPUTE]
247-
#wait_list.extend(node_user)
253+
# wait_list.extend(node_user)
248254
elif node_type == NodeType.REDUCE_SCATTER:
249255
if len(wait_list) > 0:
250256
# move the i-th wait node before (i+1)-th reduce scatter node

0 commit comments

Comments (0)