89 changes: 89 additions & 0 deletions autoparallel/auto_bucketing.py
@@ -0,0 +1,89 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .autobucketing_util import bucket_func, bucket_plan, bucket_utils, reorder


class simplefsdp_autobucketing_config:
"""
    Config for SimpleFSDP's autobucketing pass; the defaults are chosen to give good performance.
    To make the results tunable, we expose the following parameters:
    - relax_ratio: relax compute time to include more communication in one bucket;
        with this config, comp is updated as comp * (1 + relax_ratio)
    - peak_memory_offset: relax peak memory to include more communication in one bucket;
        with this config, peak_memory is updated as (peak_memory + peak_memory_offset)
    - load_cache: set to True to load the cached estimation from save_estimation_path
    - enable_bucket_ir: set to True to bucket all_gather/reduce_scatter
    - enable_reorder_ir: set to True to reorder all_gather/reduce_scatter
    - calibrate_number: number of samples used to calibrate communication estimation
"""

relax_ratio = 0
peak_memory_offset = 0
load_cache = False
save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl"
enable_bucket_ir = True
enable_reorder_ir = True
calibrate_number = 40


def simple_fsdp_autobucketing_reordering_pass(
snodes: list["torch._inductor.scheduler.BaseSchedulerNode"],
configs: "simplefsdp_autobucketing_config",
) -> list["torch._inductor.scheduler.BaseSchedulerNode"]:
scheduler = snodes[0].scheduler
bucketable_nodes = bucket_utils.get_bucketable_ir_nodes(
snodes, scheduler.name_to_fused_node, scheduler.name_to_buf
)

assert (
not torch._inductor.config.allow_buffer_reuse
), "bucketing algorithm requires torch._inductor.config.allow_buffer_reuse to be False"

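    # Compute bucketing plans for FSDP all_gather/reduce_scatter and merge the collectives accordingly.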
if configs.enable_bucket_ir:
all_gather_plan, reduce_scatter_plan = bucket_plan.get_simplefsdp_auto_plan(
scheduler,
snodes,
scheduler.name_to_buf,
scheduler.name_to_fused_node,
bucketable_nodes,
configs,
)

snodes = bucket_func.bucket_fsdp_all_gather_with_plan(
scheduler,
snodes,
scheduler.name_to_buf,
scheduler.name_to_fused_node,
all_gather_plan,
bucketable_nodes,
)
if len(reduce_scatter_plan) > 0:
snodes = bucket_func.bucket_fsdp_reduce_scatter_with_plan(
scheduler,
snodes,
scheduler.name_to_buf,
scheduler.name_to_fused_node,
reduce_scatter_plan,
bucketable_nodes,
)

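    # Reorder all_gather/reduce_scatter scheduler nodes; the reorder must preserve the total node count.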
if configs.enable_reorder_ir:
print("Reorder scheduler nodes with autobucketing algroithm")
node_length = len(snodes)
snodes = reorder.reorder_all_gather(
snodes, bucketable_nodes, all_gather_before_last_wait=False
)
assert node_length == len(
snodes
), f"Missed nodes in reordering all gather: expected {node_length}, but got {len(snodes)}"
snodes = reorder.reorder_reduce_scatter(snodes, bucketable_nodes)
assert node_length == len(
snodes
), f"Missed nodes in reordering reduce scatter: expected {node_length}, but got {len(snodes)}"

return snodes
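
Outside the diff itself, a minimal usage sketch of how this pass might be wired into an Inductor compile. The registration hook (torch._inductor.config.reorder_for_compute_comm_overlap_passes) and binding the config with functools.partial are assumptions for illustration, not part of this PR.

# Minimal sketch (assumed wiring, not part of this PR).
import functools

import torch

from autoparallel.auto_bucketing import (
    simple_fsdp_autobucketing_reordering_pass,
    simplefsdp_autobucketing_config,
)

# The bucketing algorithm requires buffer reuse to be disabled (enforced by the assert in the pass).
torch._inductor.config.allow_buffer_reuse = False

# Optionally tune the class-level config attributes; the defaults are meant to work out of the box.
simplefsdp_autobucketing_config.calibrate_number = 40

# Assumed hook: Inductor's compute/comm-overlap reordering passes take the scheduler-node list and
# return a (possibly reordered) list, so the config argument is bound ahead of time with partial.
torch._inductor.config.reorder_for_compute_comm_overlap = True
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
    functools.partial(
        simple_fsdp_autobucketing_reordering_pass,
        configs=simplefsdp_autobucketing_config,
    )
]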