From 0ce9c0c677ca052d3003d286c8ceaef24a89392b Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Mon, 6 Oct 2025 01:37:19 -0700 Subject: [PATCH 1/2] [asynctp] Optimize agmm lastdim via addmm_ stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/190, branch: IvanKobzarev/stack/7 --- autoparallel/asynctp_ops.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/autoparallel/asynctp_ops.py b/autoparallel/asynctp_ops.py index 6d9db73c..1e18df63 100644 --- a/autoparallel/asynctp_ops.py +++ b/autoparallel/asynctp_ops.py @@ -625,11 +625,6 @@ def _fused_all_gather_matmul_last_gather_dim_impl( def unflatten(t: torch.Tensor) -> torch.Tensor: return t.view(*leading_dims, -1) - A_out_leading_dims = list(A_shard.shape[:-1]) - - def unflatten_A_out(t: torch.Tensor) -> torch.Tensor: - return t.view(*A_out_leading_dims, -1) - A_flat_out = A_shard_flat.new_empty( A_shard_flat.shape[0] * group.size(), A_shard_flat.shape[1], @@ -645,19 +640,17 @@ def unflatten_A_out(t: torch.Tensor) -> torch.Tensor: for B, out_dtype in zip(Bs, out_dtypes) ] - # Additional allocation for partials output, - # That will be reduced into output. - output_partials = [torch.empty_like(out) for out in outputs] - first = True def default_consumer(shard: torch.Tensor, rank: int) -> None: nonlocal first for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): - out = outputs[idx] if first else output_partials[idx] - mm_out_op(shard, B_shards[idx][rank], **kwargs, out=out) - if not first: - outputs[idx] += output_partials[idx] + out = outputs[idx] + if first: + torch.ops.aten.mm.out(shard, B_shards[idx][rank], **kwargs, out=out) + else: + out.addmm_(shard, B_shards[idx][rank]) + first = False _pipelined_all_gather_and_consume_last_dim( @@ -672,7 +665,7 @@ def default_consumer(shard: torch.Tensor, rank: int) -> None: # This path is inefficient and will be filtered out at passes stage # Added only for completness. 
A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1) - ret_A = unflatten_A_out(A_split_cat_out_flat) + ret_A = unflatten(A_split_cat_out_flat) return ret_A, [unflatten(output) for output in outputs] From 3b9f91493407196896bd59c132a0b253e2e5bdfb Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Mon, 6 Oct 2025 01:37:50 -0700 Subject: [PATCH 2/2] DEBUG asynctp stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/191, branch: IvanKobzarev/stack/8 --- autoparallel/api.py | 4 + autoparallel/optimize_sharding.py | 137 +++++++++++++++++++++++++++++- examples/example_llama3.py | 59 +++++++++++-- mast/sweep.py | 6 ++ 4 files changed, 195 insertions(+), 11 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index ce087738..8df62a68 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -180,6 +180,7 @@ def __init__( enable_ac: bool = True, # None means 'auto' ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto", + enable_asynctp: bool = False, **kwargs, ): self.stack = ExitStack() @@ -210,6 +211,8 @@ def __init__( self.enable_ac = enable_ac self.ac_stage_size_in_GiB = ac_stage_size_in_GiB + self.enable_asynctp = enable_asynctp + # NB: rest of the construction happens in __enter__ self.active = False @@ -236,6 +239,7 @@ def __enter__(self): self.mesh, rescale_grad_comm_cost_for_mp, repeated_subgraphs=self.kwargs.get("repeated_subgraphs", False), + enable_asynctp=self.enable_asynctp, ) # makes sharding of params and gradients the same diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index 691ec9e1..8831b991 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -105,6 +105,8 @@ from .propagation_rules import _create_all_options from .utils import get_local_map_placement_option, get_placement_options +aten = torch.ops.aten + def _debug_node(node): def my_print(x): @@ -128,7 +130,12 @@ def _get_next_name(name): class ShardingOptimizer: def __init__( - self, gm, mesh, rescale_grad_comm_cost_for_mp=1.0, repeated_subgraphs=False + self, + gm, + mesh, + rescale_grad_comm_cost_for_mp=1.0, + repeated_subgraphs=False, + enable_asynctp=False, ): self.gm = gm self.graph = gm.graph @@ -139,11 +146,13 @@ def __init__( self.strats = self.build_sharding_metadata() self.cluster_links = {} + repeated_subgraphs = True if repeated_subgraphs: t = time.time() clusters = get_identical_regions(self.gm.graph, self.strats) print(f"Found {len(clusters)} clusters in {time.time() - t:.2f}s") self.create_cluster_links(clusters) + self.enable_asynctp = enable_asynctp # ds: Decision variables dictionary mapping (s_i, argi, ss, ii) -> ILP variable data # Each key represents a choice of input placement ii and output placement ss @@ -556,11 +565,11 @@ def print_old(self): print(self.get_violated_constraints_log()) def get_log(self, colored=False): - from torch.fx.graph import _color_fns, _identity opt = {} nodes = list(self.graph.nodes) + log_shapes = False for x in self.res: opt.setdefault(nodes[x[0]], []).append(self.ds[x]) @@ -600,10 +609,28 @@ def get_log(self, colored=False): code.insert(l_id, line) l_id += 1 continue + log_extra = "" + if log_shapes: + if ( + isinstance(node, torch.fx.Node) + and "val" in node.meta + and isinstance(node.meta["val"], torch.Tensor) + ): + log_extra += "(" + for arg in node.args: + if ( + isinstance(arg, torch.fx.Node) + and "val" in arg.meta + and isinstance(arg.meta["val"], torch.Tensor) + ): + log_extra += str(list(arg.meta["val"].shape)) + log_extra += ") -> " + 
log_extra += str(list(node.meta["val"].shape))
+ log_extra += "\n"
 # LOL
 while not code[l_id].lstrip().startswith(repr(node)):
 l_id += 1
- code[l_id] += line
+ code[l_id] = log_extra + code[l_id] + line
 l_id += 1
 code = "\n".join(code)
 total_cost = sum(self.ds[x]["cost"] for x in self.res)
@@ -641,6 +668,11 @@ def get_solution(self, verbose=False):
 # add their costs
 for x in self.ds.values():
 opt_target[x["va"]] += x["cost"]
+
+ if self.enable_asynctp:
+ for va, cost in self.add_asynctp_scores().items():
+ opt_target[va] += cost
+
 self.prob += pulp.lpSum([va * cost for va, cost in opt_target.items()])
 
 # solver = pulp.HiGHS(msg=verbose)
@@ -885,6 +917,105 @@ def add_sharded_output_constraint(self, output_placements=None):
 "them from the graph to avoid aliasing."
 )
 
+ def add_asynctp_scores(self):
+ # Encourage placements that enable asyncTP fusions by discounting
+ # a fraction (X %) of their comm_cost:
+ # 1. ag + mm: S(d) -> R, d < mm.ndim - 1
+ # 2. mm + rs: P -> S(d)
+ # TODO1: Filter out FSDP ag/rs that will not be asyncTPed.
+ # TODO2: With AsyncTP we should get perf wins by overlapping
+ # ((group_size - 1) / group_size) of the communication,
+ # minus the cost of the decomposition.
+ # For this we need to get group_size from the redistribution.
+ def _get_transformations(src_spec, tgt_spec):
+ # TODO: Use the real transform preparation.
+ # For now we just compare placements left to right.
+ src_pls = src_spec.placements
+ tgt_pls = tgt_spec.placements
+ transformations = []
+ for src_pl, tgt_pl in zip(src_pls, tgt_pls):
+ if src_pl == tgt_pl:
+ continue
+ transformations.append((src_pl, tgt_pl))
+ return transformations
+
+ def _produces_asynctp_ag(src_spec, tgt_spec, mm_dim):
+ # Check that the last transition will be S(dim) -> Replicate
+
+ transformations = _get_transformations(src_spec, tgt_spec)
+ if len(transformations) == 0:
+ return False
+ last_t = transformations[-1]
+ return (
+ last_t[1].is_replicate()
+ and last_t[0].is_shard()
+ and last_t[0].dim < mm_dim - 1
+ )
+
+ def _produces_asynctp_rs(src_spec, tgt_spec, mm_dim):
+ # Check that the last transition will be P -> S(dim)
+ transformations = _get_transformations(src_spec, tgt_spec)
+ if len(transformations) == 0:
+ return False
+ last_t = transformations[-1]
+ return last_t[0].is_partial() and last_t[1].is_shard()
+
+ va_cost_delta = defaultdict(int)
+ strats = self.strats
+ for s_i, (node, s) in enumerate(strats.items()):
+ if not (node.op == "call_function" and node.target == aten.mm.default):
+ continue
+ mm_n = node
+ # Incentivize ag+mm:
+ # arg0 of MM should be S(dim) -> R to have an all_gather before the mm
+ a_n = node.args[0]
+ mm_sts = s.strategies
+ for mm_st_i, mm_st in enumerate(mm_sts):
+ a_sts = strats[a_n].strategies
+ mm_tgt_spec = mm_st.input_specs[0]
+ for a_st_i, a_st in enumerate(a_sts):
+ a_src_spec = a_st.output_spec
+ # TODO: Is adding a constraint to the arg enough, or do we need to follow the arg's
+ # ancestors and find the first sharding change?
+ if _produces_asynctp_ag(
+ a_src_spec, mm_tgt_spec, mm_n.meta["val"].ndim
+ ):
+ # TODO: We want to calculate the cost of the specific AG, as it will be pipelined;
+ # for now we just use the redistribution cost.
+ cost = mm_st.redistribute_cost[0][a_st_i]
+ if cost == float("inf"):
+ continue
+ va = self.ds[(s_i, 0, mm_st_i, a_st_i)]["va"]
+ va_cost_delta[va] += -0.3 * cost
+ # mm+rs
+ src_spec = mm_st.output_spec
+ if len(mm_n.users) == 0:
+ continue
+ mm_user = next(iter(mm_n.users))
+ mm_user_s_i = self.node_map[mm_user]
+ mm_u_arg_mm_i = -1
+ for i, arg in enumerate(mm_user.args):
+ if arg == mm_n:
+ mm_u_arg_mm_i = i
+ assert mm_u_arg_mm_i != -1
+ mm_user_sts = strats[mm_user].strategies
+ for mm_u_st_i, mm_u_st in enumerate(mm_user_sts):
+ if _produces_asynctp_rs(
+ src_spec,
+ mm_u_st.input_specs[mm_u_arg_mm_i],
+ mm_n.meta["val"].ndim,
+ ):
+ # TODO: We want to calculate the cost of the specific RS, as it will be pipelined;
+ # for now we just use the redistribution cost.
+ cost = mm_u_st.redistribute_cost[mm_u_arg_mm_i][mm_u_st_i]
+ if cost == float("inf"):
+ continue
+ key = (mm_user_s_i, mm_u_arg_mm_i, mm_u_st_i, mm_st_i)
+ va = self.ds[key]["va"]
+ va_cost_delta[va] += -0.3 * cost
+
+ return va_cost_delta
+
 def validate(self):
 for node in self.graph.nodes:
 if node.op != "call_function":
diff --git a/examples/example_llama3.py b/examples/example_llama3.py
index bc41e96c..57b2ae2c 100644
--- a/examples/example_llama3.py
+++ b/examples/example_llama3.py
@@ -211,7 +211,13 @@ def add_tp_constraints(autop):
 
 # parallelize the model
 with AutoParallel(
- model, input_fn, mesh, mp_policy, compile=True, repeated_subgraphs=True
+ model,
+ input_fn,
+ mesh,
+ mp_policy,
+ compile=True,
+ repeated_subgraphs=True,
+ enable_asynctp=enable_asynctp,
 ) as autop:
 autop.add_parameter_memory_constraint(low=None, high=None)
 
@@ -229,22 +235,59 @@ def add_tp_constraints(autop):
 if enable_manual_constraint and not use_1d_mesh:
 add_tp_constraints(autop)
 
- if enable_asynctp:
- from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+ enable_overlap_scheduling = True
+ enable_overlap_scheduling_bucketing = True
+ if enable_overlap_scheduling_bucketing:
+ assert (
+ enable_overlap_scheduling
+ ), "bucketing cannot be used without overlap scheduling"
+ enable_asynctp = True
+ # Set the flags on the module itself; rebinding names pulled in via
+ # `from autoparallel.asynctp import ...` would only create local
+ # variables and leave the module-level flags untouched.
+ import autoparallel.asynctp as asynctp_mod
+
+ asynctp_mod._micro_pipeline_tp_ag_transpose_mm_enabled = True
+ asynctp_mod._micro_pipeline_tp_ag_mm_last_dim_enabled = True
+ if (
+ enable_overlap_scheduling
+ or enable_overlap_scheduling_bucketing
+ or enable_asynctp
+ ):
+ torch._inductor.config.reorder_for_peak_memory = False
+ torch._inductor.config.reorder_for_compute_comm_overlap = False
+ torch._inductor.config.allow_buffer_reuse = False
+ torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
+ enable_overlap_scheduling_bucketing
+ )
 
- enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
- enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
- torch._inductor.config._micro_pipeline_tp = False
- from autoparallel.asynctp import micro_pipeline_tp_pass
+ if enable_asynctp:
+ from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+ enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
+ enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
+ torch._inductor.config._micro_pipeline_tp = False
+ # Disable the Inductor AsyncTP passes in favor of the Autoparallel fork of these passes.
+ # TODO: Switch to the Inductor AsyncTP passes once all additions have landed.
+ from autoparallel.asynctp import micro_pipeline_tp_pass
 
 existing_post_grad_custom_post_pass = (
 torch._inductor.config.post_grad_custom_post_pass
 )
+ from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
 
 def _pass(graph):
 if existing_post_grad_custom_post_pass is not None:
 existing_post_grad_custom_post_pass(graph)
- micro_pipeline_tp_pass(graph)
+
+ collective_info = None
+ if enable_overlap_scheduling:
+ overlap_scheduler = OverlapScheduler(graph.owning_module)
+ overlap_scheduler.run()
+ collective_info = overlap_scheduler.collective_info
+
+ if enable_asynctp:
+ micro_pipeline_tp_pass(graph, collective_info)
 
 torch._inductor.config.post_grad_custom_post_pass = _pass
diff --git a/mast/sweep.py b/mast/sweep.py
index e867bd43..130cace2 100644
--- a/mast/sweep.py
+++ b/mast/sweep.py
@@ -123,6 +123,12 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
 "--model.name=llama3",
 "--compile.enable",
 ],
+ "llama3_FSDP_tp_async_tp_compile": llama3_2d_common_opts
+ + [
+ "--model.name=llama3",
+ "--compile.enable",
+ "--parallelism.enable_async_tensor_parallel",
+ ],
 "llama3_autop_2d_compile": llama3_2d_common_opts
 + [
 "--model.name=llama3_auto_parallel",
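
Two minimal, self-contained sketches of the techniques in the patches above. They are illustrations only: shapes, costs, and variable names are made up, and neither snippet is the actual autoparallel implementation.

PATCH 1 drops the separate partials buffer in the all-gather + matmul consumer: the first per-rank partial is written directly into the output with mm.out, and every subsequent partial is accumulated in place with addmm_ (out += shard @ B_shard), avoiding the extra allocation and the follow-up add. The accumulation pattern in isolation:

    import torch

    # Made-up sizes standing in for the per-rank shards seen by default_consumer.
    M, K_shard, N, world_size = 128, 64, 256, 4
    A_shards = [torch.randn(M, K_shard) for _ in range(world_size)]
    B_shards = [torch.randn(K_shard, N) for _ in range(world_size)]
    out = torch.empty(M, N)

    for rank, (a, b) in enumerate(zip(A_shards, B_shards)):
        if rank == 0:
            torch.mm(a, b, out=out)  # first partial initializes the output buffer
        else:
            out.addmm_(a, b)         # out += a @ b, no extra partials allocation

    reference = sum(a @ b for a, b in zip(A_shards, B_shards))
    torch.testing.assert_close(out, reference)

PATCH 2's add_asynctp_scores does not add constraints; it only perturbs the ILP objective, subtracting 30% of the redistribution cost from decision variables whose placements would produce an all_gather+mm or mm+reduce_scatter pattern that asyncTP can fuse. A toy pulp model (hypothetical variable names and arbitrary costs) showing how such a negative delta tilts the solver toward the fusable choice:

    import pulp

    prob = pulp.LpProblem("placement_toy", pulp.LpMinimize)
    x_repl = pulp.LpVariable("mm_arg0_replicated", cat="Binary")
    x_ag = pulp.LpVariable("mm_arg0_shard_then_allgather", cat="Binary")
    prob += x_repl + x_ag == 1  # exactly one placement is chosen

    cost = {x_repl: 90.0, x_ag: 100.0}  # base redistribution costs (arbitrary)
    cost[x_ag] += -0.3 * 100.0          # asyncTP incentive: -30% of the comm cost
    prob += pulp.lpSum(v * c for v, c in cost.items())

    prob.solve(pulp.PULP_CBC_CMD(msg=False))
    print({v.name: int(v.value()) for v in (x_repl, x_ag)})  # x_ag wins: 70 < 90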