diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 71d3fba..edca044 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -46,3 +46,4 @@ jobs: python examples/example_dcp.py python examples/example_local_map.py python examples/example_ds3_local_map.py + python examples/example_pp_graph_partition.py diff --git a/.gitignore b/.gitignore index 65fe220..41c2f3e 100644 --- a/.gitignore +++ b/.gitignore @@ -10,5 +10,6 @@ build/ dist/ +tmp/ .vscode/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c622af9..e23bf1b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,6 +37,7 @@ repos: - repo: local hooks: - id: mypy + require_serial: true name: mypy entry: mypy language: system diff --git a/autoparallel/_passes/graph_multiplex.py b/autoparallel/_passes/graph_multiplex.py new file mode 100644 index 0000000..e1eadc8 --- /dev/null +++ b/autoparallel/_passes/graph_multiplex.py @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import copy + +import torch +import torch.fx as fx + + +def multiplex_fw_bw_graph( + fw_gm: fx.GraphModule, bw_gm: fx.GraphModule +) -> fx.GraphModule: + """ + Multiplexes forward and backward graphs into a single unified graph module. + + This function combines a forward graph and a backward graph into one multiplexed + graph by merging their nodes and outputs. The resulting graph has: + - All placeholders from both forward and backward graphs (backward followed by forward) + - All computation nodes from both graphs (backward followed by forward) + - Combined outputs (backward outputs followed by forward outputs) + + Args: + fw_gm: The forward graph module containing the forward computation + bw_gm: The backward graph module containing the backward computation + + Returns: + A multiplexed fx.GraphModule containing both forward and backward computations + with backward outputs appearing before forward outputs + + Note: + The function preserves node metadata during the merging process. 
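+
+    Example:
+        An illustrative sketch, assuming fw_gm and bw_gm are the forward/backward
+        graph modules produced by a partitioner such as partition_joint_with_descriptors:
+
+            multiplexed_gm = multiplex_fw_bw_graph(fw_gm, bw_gm)
+            # placeholders: backward placeholders first, then forward placeholders
+            # outputs:      backward outputs first, then forward outputs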
+ """ + # Mapping to track correspondence between backward graph nodes and new nodes + old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {} + + # Start with a deep copy of the forward graph as the base + multiplexed_gm = copy.deepcopy(fw_gm) + + # Collect all placeholder nodes from the backward graph + bw_placeholders = [] + for n in bw_gm.graph.nodes: + if n.op == "placeholder": + bw_placeholders.append(n) + + # Insert backward placeholders at the beginning of the multiplexed graph + # Reversed order ensures correct execution sequence + with multiplexed_gm.graph.inserting_before(): + for n in reversed(bw_placeholders): + new_placeholder = multiplexed_gm.graph.placeholder(n.name) + new_placeholder.meta = n.meta + new_placeholder.target = new_placeholder.name + old_node_to_new_node[n] = new_placeholder + + # Find the last placeholder and the output node in the multiplexed graph + insert_point = None + multiplexed_graph_op_node = None + for n in multiplexed_gm.graph.nodes: + if n.op == "placeholder": + insert_point = n + if n.op == "output": + multiplexed_graph_op_node = n + + # Copy all computation nodes from backward graph into multiplexed graph + bw_graph_op_node = None + for n in bw_gm.graph.nodes: + if n.op == "placeholder": + continue + if n.op == "output": + bw_graph_op_node = n + continue + with multiplexed_gm.graph.inserting_after(insert_point): + # Copy node and remap its arguments using the node mapping + new_node = multiplexed_gm.graph.node_copy( + n, lambda x: old_node_to_new_node[x] + ) + new_node.meta = n.meta + old_node_to_new_node[n] = new_node + insert_point = new_node + + assert bw_graph_op_node is not None + assert multiplexed_graph_op_node is not None + + # Collect output arguments from backward graph, remapping to new nodes + bw_op_node_args = [ + old_node_to_new_node[n] if n is not None else None + for n in bw_graph_op_node.args[0] + ] + + # Collect output arguments from forward graph + fw_op_node_args = list(multiplexed_graph_op_node.args[0]) + + # Remove the old output node and create new combined output + insert_point = multiplexed_graph_op_node.prev + multiplexed_gm.graph.erase_node(multiplexed_graph_op_node) + + # Create combined output with backward outputs first, then forward outputs + with multiplexed_gm.graph.inserting_after(insert_point): + multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args) + + multiplexed_gm.graph.eliminate_dead_code() + multiplexed_gm.graph.lint() + multiplexed_gm.recompile() + return multiplexed_gm diff --git a/autoparallel/_passes/graph_partition.py b/autoparallel/_passes/graph_partition.py index 2d31c18..736c5ba 100644 --- a/autoparallel/_passes/graph_partition.py +++ b/autoparallel/_passes/graph_partition.py @@ -79,9 +79,6 @@ def partition_joint_with_descriptors( num_mutate_inputs = len( [x for x in fw_metadata.input_info if x.mutates_data or x.mutates_metadata] ) - print(fw_module.graph) - print(fw_module.graph) - return ( fw_module, bw_module, diff --git a/autoparallel/_passes/split_di_dw_graph.py b/autoparallel/_passes/split_di_dw_graph.py new file mode 100644 index 0000000..c4987cb --- /dev/null +++ b/autoparallel/_passes/split_di_dw_graph.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
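+
+# Illustrative usage sketch (the variable names below are assumptions, not part of this
+# module; bw_gm is the backward graph module produced by the partitioner):
+#
+#     bw_dI, bw_dW = split_di_dw_graph(bw_gm, num_weight_gradients=num_param_grads)
+#     # bw_dI computes gradients w.r.t. the stage inputs (scheduled first),
+#     # bw_dW computes gradients w.r.t. the weights (can be deferred by the schedule).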
+
+import copy
+
+import torch.fx as fx
+from functorch.compile import default_partition
+
+# we are running the default partitioner on the bw graph, which requires the AC tags to be removed.
+# At this stage we have already finished running AC anyway, since we have a bw graph
+
+
+def remove_recompute_tags(bw_gm):
+    for n in bw_gm.graph.nodes:
+        if "recompute" in n.meta:
+            del n.meta["recompute"]
+
+
+# We are using the default partitioner to split our backward into dI and dW subgraphs.
+# We want to generate the dI subgraph *first*, because:
+# - in pipelining we generally want to schedule dI compute before dW
+# - the dI compute will potentially compute more activations that we need to plumb into the dW compute
+# Today, the default partitioner requires that you split on the first K outputs of your combined graph.
+# So here, we reorder the outputs of the backward so grad_inputs are first.
+
+
+def reorder_output_grads(bw_gm, num_weight_gradients):
+    outputs = bw_gm.graph.find_nodes(op="output")
+    assert len(outputs) == 1
+    output = outputs[0]
+    assert isinstance(output.args[0], tuple)
+    grad_weights, grad_inputs = (
+        output.args[0][:num_weight_gradients],
+        output.args[0][num_weight_gradients:],
+    )
+    new_out_tuple = grad_inputs + grad_weights
+    with bw_gm.graph.inserting_after(output):
+        # TODO: also set the new node's meta properly
+        new_out = bw_gm.graph.output(new_out_tuple)
+    output.replace_all_uses_with(new_out)
+    bw_gm.graph.erase_node(output)
+    return len(grad_inputs)
+
+
+# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
+
+
+def split_di_dw_graph(
+    bw_gm: fx.GraphModule, *, num_weight_gradients
+) -> tuple[fx.GraphModule, fx.GraphModule]:
+    # we could consider doing this in a non-mutating way
+    bw_gm = copy.deepcopy(bw_gm)
+    remove_recompute_tags(bw_gm)
+    num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients)
+    bw_gm.recompile()
+
+    args = [x.meta["val"] for x in bw_gm.graph.find_nodes(op="placeholder")]
+
+    bw_inputs, bw_weights = default_partition(
+        bw_gm, args, num_fwd_outputs=num_input_gradients
+    )
+    return bw_inputs, bw_weights
diff --git a/autoparallel/_passes/split_fsdp_collectives.py b/autoparallel/_passes/split_fsdp_collectives.py
new file mode 100644
index 0000000..a411fb3
--- /dev/null
+++ b/autoparallel/_passes/split_fsdp_collectives.py
@@ -0,0 +1,150 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
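+
+# Illustrative usage sketch (mirrors examples/example_pp_graph_partition.py; fw_gm and
+# bw_gm are assumed to be the partitioned forward/backward graph modules):
+#
+#     fw_unshard_g, fw_main_g = split_fsdp_prefetch(fw_gm.graph)
+#     bw_main_g, bw_reduce_grad_g = split_fsdp_reduce_scatters_epilogue(bw_gm.graph)
+#
+# The prefetch graph holds the leading FSDP all-gathers and the epilogue graph holds the
+# trailing reduce-scatters, so a pipeline runner can schedule them separately from compute.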
+ +import dataclasses +from contextlib import contextmanager +from functools import partial +from typing import Any + +import torch +import torch.fx.node +import torch.utils._pytree as pytree +from torch._functorch._aot_autograd.descriptors import AOTOutput +from torch._functorch.partitioners import _extract_graph_with_inputs_outputs +from torch._inductor.fx_passes.bucketing import ( + is_all_gather_into_tensor, + is_reduce_scatter_tensor, +) + + +@contextmanager +def exclude_from_fx_side_effectful(exclude_vals: set[Any]): + original_val = torch.fx.node._side_effectful_functions.copy() + try: + torch.fx.node._side_effectful_functions -= exclude_vals + yield + finally: + torch.fx.node._side_effectful_functions.clear() + torch.fx.node._side_effectful_functions.update(original_val) + + +exclude_wait_from_fx_side_effectful = partial( + exclude_from_fx_side_effectful, + { + torch.ops._c10d_functional.wait_tensor, + torch.ops._c10d_functional.wait_tensor.default, + }, +) + + +@dataclasses.dataclass(frozen=True) +class PrefetchOutput(AOTOutput): + pass + + +@dataclasses.dataclass(frozen=True) +class EpilogueInput(AOTOutput): + pass + + +def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]: + g_ins = g.find_nodes(op="placeholder") + prefetch_g_outs_map = [] + + for g_in in g_ins: + n = g_in + last_ag = None + while True: + if len(n.users) != 1: + break + user = next(iter(n.users)) + if len(user.all_input_nodes) > 1: + break + n = user + if is_all_gather_into_tensor(n): + last_ag = n + if last_ag is None: + prefetch_g_outs_map.append(g_in) + else: + w_n = next(iter(last_ag.users)) + prefetch_g_outs_map.append(w_n) + + prefetch_g_outs = prefetch_g_outs_map + prefetch_g_outs_descs: list[AOTOutput] = [ + PrefetchOutput() for _ in range(len(prefetch_g_outs)) + ] + g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) + g_outs_descs = pytree.arg_tree_leaves( + next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs)) + ) + with exclude_wait_from_fx_side_effectful(): + prefetch_g = _extract_graph_with_inputs_outputs( + g, + g_ins, + prefetch_g_outs, + prefetch_g_outs_descs, + ignore_must_be_in_fw_bw=True, + ) + + main_g = _extract_graph_with_inputs_outputs( + g, + prefetch_g_outs, + g_outs, + g_outs_descs, + ignore_must_be_in_fw_bw=True, + ) + return prefetch_g, main_g + + +def split_fsdp_reduce_scatters_epilogue( + g: torch.fx.Graph, +) -> tuple[torch.fx.Graph, torch.fx.Graph]: + g_ins = g.find_nodes(op="placeholder") + g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) + g_outs_descs = pytree.arg_tree_leaves( + next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs)) + ) + + g_outs_map = [] + for g_out in g_outs: + n = g_out + last_rs = None + while n is not None: + if len(n.all_input_nodes) != 1: + break + n_in = n.all_input_nodes[0] + if len(n_in.users) > 1: + break + prev_n = n + n = n_in + if is_reduce_scatter_tensor(prev_n): + # In AP for mesh dim > 1 + # The reduction of gradients happen in multiple steps + last_rs = n + if last_rs is not None: + g_outs_map.append(last_rs) + else: + g_outs_map.append(g_out) + + epi_g_ins = [n for n in g_outs_map if n is not None] + epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))] + + with exclude_wait_from_fx_side_effectful(): + main_g = _extract_graph_with_inputs_outputs( + g, + g_ins, + epi_g_ins, + epi_g_ins_descs, + ignore_must_be_in_fw_bw=True, + ) + epi_g = _extract_graph_with_inputs_outputs( + g, + 
epi_g_ins, + g_outs, + g_outs_descs, + ignore_must_be_in_fw_bw=True, + ) + + return main_g, epi_g diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py index 95fc2c9..544e571 100644 --- a/autoparallel/_testing/models/dsv3.py +++ b/autoparallel/_testing/models/dsv3.py @@ -1574,3 +1574,58 @@ def forward( h = self.norm(h) if self.norm is not None else h output = self.output(h) if self.output is not None else h return output + + +######################## +# Pipeline stuff start # +######################## + + +class DeepSeekV3StageI(nn.Module): + def __init__(self, layers, config): + super().__init__() + self.layers = layers + self.register_buffer( + "freqs_cis", precompute_freqs_cis(config), persistent=False + ) + + def forward(self, h): + # intermediate stages only have layers + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + return h + + def init_weights(self, buffer_device: torch.device | None = None) -> None: + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + + +class DeepSeekV3Stage0(DeepSeekV3StageI): + def __init__(self, embed, layers, config): + super().__init__(layers, config) + self.tok_embeddings = embed + + def forward(self, tokens): + # torch.Size([1024, 1024]) + h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens + # torch.Size([1024, 1024, 2048]) + return super().forward(h) + + +class DeepSeekV3StageN(DeepSeekV3StageI): + def __init__(self, layers, norm, output, config): + super().__init__(layers, config) + self.norm = norm + self.output = output + + def forward(self, h): + h = super().forward(h) + h = self.norm(h) if self.norm is not None else h + output = self.output(h) if self.output is not None else h + return output + + +###################### +# Pipeline stuff end # +###################### diff --git a/autoparallel/api.py b/autoparallel/api.py index 7f524c2..9cde61e 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -28,6 +28,8 @@ from torch.export.unflatten import _AttrKind from torch.fx.experimental.symbolic_shapes import ShapeEnv +from autoparallel._passes.graph_partition import partition_joint_with_descriptors + from .activation_checkpointing import ac_joint_pass from .apply_sharding import apply_sharding_to_model from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast @@ -45,15 +47,6 @@ _APPLY_VIEW_MM_VIEW_PATTERN = False -def try_convert_fake_to_real(tensors): - out = {} - for k, t in tensors.items(): - out[k] = torch.distributed.tensor.randn( - t.shape, dtype=t.dtype, device_mesh=t.device_mesh, placements=t.placements - ) - return out - - def _get_decomp_table(): decomp_table = copy.copy(select_decomp_table()) # TODO: removing those as they cause missing DTensor propagation rules @@ -550,14 +543,54 @@ def forward(self, *args): self._register_params_and_init_weights(sharded_param_dict, sharded_buffer_dict) return self.parallel_model - def apply_placement_pp(self, sharding_placement=None) -> torch.nn.Module: + +######################## +# Pipeline stuff start # +######################## +class AutoParallelPPModule(torch.nn.Module): + def __init__( + self, + sharded_param_dict: dict[str, torch.nn.Parameter], + sharded_buffer_dict: dict[str, torch.Tensor], + init_weights_model: torch.nn.Module, + ): + super().__init__() + self._register_params_and_buffers(sharded_param_dict, sharded_buffer_dict) + + # Right now we require a convention that the user model provides an init_weights method, + 
# although we could snoop for other methods too. + if hasattr(init_weights_model, "init_weights"): + hook_params_setters(init_weights_model, self) + + def init_weights(_self, *args, **kwargs): + # this is now a deep-fake-copy of orig mod, so we don't have to use reparametrize + return init_weights_model.init_weights(*args, **kwargs) + + # assign an init_weights method onto the output mod. + # all it does is sneakily run the original user mod's init_weights method, + # but with our new DTensor sharded params attached to the user module. + self.init_weights = MethodType(init_weights, self) + + def _register_params_and_buffers(self, sharded_param_dict, sharded_buffer_dict): + + # We construct an unflattened structure on parallel_mod, + # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally + # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot + for k, v in sharded_param_dict.items(): + _assign_attr(v, self, k, attr_kind=_AttrKind.PARAMETER) + + for k, v in sharded_buffer_dict.items(): + _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER) + + def forward(self, *args): + raise NotImplementedError("This is a placeholder for the pipeline model") + + +class AutoParallelPP(AutoParallel): + def apply_placement_pp(self, sharding_placement=None) -> dict[str, Any]: sharded_param_dict, sharded_buffer_dict = self._apply_placement_common( sharding_placement ) - from autoparallel._passes.graph_partition import ( - partition_joint_with_descriptors, - ) - ( fw_module, bw_module, @@ -569,60 +602,62 @@ def apply_placement_pp(self, sharding_placement=None) -> torch.nn.Module: adjusted_flat_args, ) = partition_joint_with_descriptors(self.joint_with_descriptors) - class AutoParallelPPStage(torch.autograd.Function): - @staticmethod - def forward(ctx, *args): - fw_module = args[-2] - ctx.bw_module = args[-1] - fw_args = list(args[:-2]) - fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args) - num_inner_fwd_outputs = num_mutate_inputs + num_user_outputs - saved_intermediates = fw_outputs[num_inner_fwd_outputs:] - num_tensors_for_backward = ( - len(saved_intermediates) - num_symints_saved_for_bw - ) - tensors_to_save = saved_intermediates[:num_tensors_for_backward] - non_tensors_to_save = saved_intermediates[num_tensors_for_backward:] - ctx.save_for_backward(*tensors_to_save) - ctx.non_tensors = non_tensors_to_save - - user_outputs = fw_outputs[num_mutate_inputs:num_inner_fwd_outputs] - return user_outputs - - @staticmethod - def backward(ctx, *tangents): - bw_args = [*ctx.non_tensors, *ctx.saved_tensors, *tangents] - bw_outputs = torch.fx.Interpreter(ctx.bw_module).boxed_run(bw_args) - result = bw_outputs + (None,) * 2 - return result - - class AutoParallelPPModule(torch.nn.Module): - def __init__(self, fw_module, bw_module): - super().__init__() - self.fw_module = fw_module - self.bw_module = bw_module + print( + f"num_user_outputs: {num_user_outputs}\n" + f"num_mutate_inputs: {num_mutate_inputs}\n" + f"num_fw_outs_saved_for_bw: {num_fw_outs_saved_for_bw}\n" + f"num_symints_saved_for_bw: {num_symints_saved_for_bw}" + ) - def forward(self, *args): - # NB: don't close over the parameters/buffers, as the user may - # reassign the module! 
- # prepare_aot_module_simplified, this seems like an API gap - params_and_buffers = [ - v.to_local() - for k, v in itertools.chain( - dict(self.named_parameters(remove_duplicate=False)).items(), - dict(self.named_buffers(remove_duplicate=False)).items(), - ) - ] - boxed_args = [ - *params_and_buffers, - *args, - self.fw_module, - self.bw_module, - ] - del params_and_buffers - out = AutoParallelPPStage.apply(*boxed_args) - return out + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "autoparallel_pp_fwd_graph", + "encoding": "string", + }, + payload_fn=lambda: fw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "autoparallel_pp_bwd_graph", + "encoding": "string", + }, + payload_fn=lambda: bw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) - self.parallel_model = AutoParallelPPModule(fw_module, bw_module) - self._register_params_and_init_weights(sharded_param_dict, sharded_buffer_dict) - return self.parallel_model + graph_meta: dict[str, int] = { + "num_mutate_inputs": num_mutate_inputs, + "num_user_outputs": num_user_outputs, + "num_symints_saved_for_bw": num_symints_saved_for_bw, + "num_weight_buffer_grads": len(sharded_param_dict) + + len(sharded_buffer_dict), + } + graph_modules: dict[str, Optional[torch.fx.GraphModule]] = { + "fw": fw_module, + "full_bw": bw_module, + "bw_dI": None, + "bw_dW": None, + "unshard": None, + "reduce_grad": None, + } + self.parallel_model = AutoParallelPPModule( + sharded_param_dict, + sharded_buffer_dict, + self.init_weights_model, + ) + return { + "graph_callables": graph_modules, + "graph_meta": graph_meta, + "sharded_param_dict": sharded_param_dict, + "sharded_buffer_dict": sharded_buffer_dict, + } + + +###################### +# Pipeline stuff end # +###################### diff --git a/autoparallel/graph_pp_runner.py b/autoparallel/graph_pp_runner.py new file mode 100644 index 0000000..a4ff630 --- /dev/null +++ b/autoparallel/graph_pp_runner.py @@ -0,0 +1,417 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
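+
+# Illustrative wiring sketch (mirrors examples/example_ds3_pp.py; the variable names are
+# assumptions, not part of this module):
+#
+#     stage = GraphPipelineStage(pp_mod, graph_callables, graph_meta,
+#                                stage_index=i, num_stages=n, device=device)
+#     # build an Interleaved1F1B (_PipelineScheduleRuntime) schedule over the local stages,
+#     # then override its FORWARD / FULL_BACKWARD actions with the graph-based versions:
+#     schedule.register_custom_function(FORWARD, stage_forward)
+#     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
+#     GraphPPRunner(schedule).step(input_batch)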
+ +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union, cast + +import torch +import torch.fx as fx +from torch.distributed.pipelining.schedules import ( + _Action, + _PipelineContext, + _PipelineScheduleRuntime, + _wait_batch_p2p, +) +from torch.distributed.pipelining.stage import ( + PipelineStage, + _normalize_model_output_as_tuple, +) +from torch.distributed.tensor import DTensor + + +@dataclass +class GraphCallables: + fw: fx.GraphModule + full_bw: fx.GraphModule + bw_dI: Optional[fx.GraphModule] = None + bw_dW: Optional[fx.GraphModule] = None + unshard: Optional[fx.GraphModule] = None + reduce_grad: Optional[fx.GraphModule] = None + + +@dataclass +class GraphMeta: + num_mutate_inputs: int + num_user_outputs: int + num_symints_saved_for_bw: int + num_weight_buffer_grads: int + + +class GraphPipelineStage(PipelineStage): + def __init__( + self, + submodule: torch.nn.Module, + graph_callables: GraphCallables, + graph_meta: GraphMeta, + stage_index: int, + num_stages: int, + device: torch.device, + input_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None, + output_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None, + group: Optional[torch.distributed.ProcessGroup] = None, + dw_builder: Optional[Callable[[], Callable[..., None]]] = None, + ): + super().__init__( + submodule=submodule, + stage_index=stage_index, + num_stages=num_stages, + device=device, + input_args=input_args, + output_args=output_args, + group=group, + dw_builder=dw_builder, + ) + self.graph_callables = graph_callables + self.graph_meta = graph_meta + self.state: dict[str, list[Any]] = { + "sharded_params": [], + "unsharded_params": [], + "buffers": [], + "sharded_grads": [], + "unsharded_grads": [], + } + + +def _run_fw_module( + fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any] +) -> tuple[Any, tuple[Any, Any]]: + assert len([n for n in fw_module.graph.nodes if n.op == "placeholder"]) == len( + fw_args + ), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}" + fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args) + num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs + saved_intermediates = fw_outputs[num_inner_fwd_outputs:] + num_tensors_for_backward = ( + len(saved_intermediates) - graph_meta.num_symints_saved_for_bw + ) + tensors_for_backward = saved_intermediates[:num_tensors_for_backward] + non_tensors_for_backward = saved_intermediates[num_tensors_for_backward:] + save_for_backward = (tensors_for_backward, non_tensors_for_backward) + user_outputs = fw_outputs[graph_meta.num_mutate_inputs : num_inner_fwd_outputs] + if len(user_outputs) == 1: + user_outputs = user_outputs[0] + return user_outputs, save_for_backward + + +def _run_full_bw_module( + bw_module: fx.GraphModule, graph_meta: GraphMeta, bw_args +) -> tuple[Any, list[Any]]: + assert len([n for n in bw_module.graph.nodes if n.op == "placeholder"]) == len( + bw_args + ), "Mismatched number of inputs to bwd" + bw_outputs = torch.fx.Interpreter(bw_module).boxed_run(bw_args) + param_buffer_grads = bw_outputs[: graph_meta.num_weight_buffer_grads] + input_grads = bw_outputs[graph_meta.num_weight_buffer_grads :] + return input_grads, param_buffer_grads + + +def _run_forward_microbatch(stage: GraphPipelineStage, *args) -> tuple[Any, Any]: + fw_args = [ + *stage.state["unsharded_params"], + *stage.state["buffers"], + *args, + ] + user_outputs, saved_intermediates = 
_run_fw_module(
+        stage.graph_callables.fw, stage.graph_meta, fw_args
+    )
+    return (user_outputs, saved_intermediates)
+
+
+def _run_backward_microbatch(
+    backward_stage: GraphPipelineStage, bwd_kwargs: dict[str, Any]
+):
+    tangents = bwd_kwargs["tangents"]
+    saved_intermediates = bwd_kwargs["saved_intermediates"]
+    tensors_for_backward, non_tensors_for_backward = saved_intermediates
+
+    bw_args = [
+        *non_tensors_for_backward,
+        *tensors_for_backward,
+        *tangents,
+    ]
+    input_grads, param_buffer_grads = _run_full_bw_module(
+        backward_stage.graph_callables.full_bw, backward_stage.graph_meta, bw_args
+    )
+
+    unsharded_grads = backward_stage.state["unsharded_grads"]
+    grads_to_accumulate = param_buffer_grads[
+        : len(backward_stage.state["sharded_params"])
+    ]
+    assert len(unsharded_grads) == len(grads_to_accumulate)
+    # Accumulate this microbatch's parameter/buffer gradients into the stage state.
+    # Index into the list (rather than rebinding the loop variable) so that gradients
+    # produced for entries that are still None are actually stored back.
+    for i, grad_to_accumulate in enumerate(grads_to_accumulate):
+        if grad_to_accumulate is not None:
+            if unsharded_grads[i] is None:
+                unsharded_grads[i] = grad_to_accumulate
+            else:
+                unsharded_grads[i] += grad_to_accumulate
+    return input_grads
+
+
+def stage_forward(
+    action: _Action,
+    ctx: _PipelineContext,
+) -> None:
+    schedule = ctx.schedule_ref
+    assert isinstance(schedule, _PipelineScheduleRuntime)
+    stage_index_to_stage: dict[int, GraphPipelineStage] = {
+        stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages
+    }
+    stage = stage_index_to_stage[action.stage_index]
+    stage_index = stage.stage_index
+
+    mb_index = action.microbatch_index
+    assert mb_index is not None
+    fwd_recv_ops = schedule.fwd_recv_ops
+    arg_mbs = ctx.arg_mbs
+    kwarg_mbs = ctx.kwarg_mbs
+
+    is_next_stage_on_this_rank = stage_index + 1 in stage_index_to_stage
+    is_prev_stage_on_this_rank = stage_index - 1 in stage_index_to_stage
+
+    if (
+        not stage.is_first
+        # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
+        and not is_prev_stage_on_this_rank
+    ):
+        assert (
+            stage_index,
+            mb_index,
+        ) in fwd_recv_ops, f"Computing {action=} before receiving input"
+
+        _wait_batch_p2p(fwd_recv_ops.pop((stage_index, mb_index)))
+
+    args = arg_mbs[mb_index]  # type: ignore[index]
+    kwargs = kwarg_mbs[mb_index]  # type: ignore[index]
+    assert not kwargs  # TODO: if kwargs can always be ignored, maybe remove?
+
+    if stage.is_first:
+        # First stage doesn't need to receive anything
+        composite_args = args
+    else:
+        # Receive activations for this chunk
+        # Activations only come in args form
+        composite_args = stage._retrieve_recv_activations(mb_index)
+
+    # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?
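+
+    # Run this stage's compiled forward graph on the microbatch; saved_intermediates
+    # is the (tensors, symints) pair that the corresponding backward graph will consume.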
+ + output, saved_intermediates = _run_forward_microbatch(stage, *composite_args) + + # See [Note: pipeline model output type] + output_tuple = _normalize_model_output_as_tuple(output) + + # Prepare for final output merge or reduction + # Output chunks is only used for the last stage since we only merge the output of the last stage + if stage.is_last: + stage.output_chunks.append(output) + + stage.fwd_cache[mb_index] = ( + output_tuple, # stage_output + saved_intermediates, # saved_intermediates + ) + + # stage._validate_fwd_outputs(output_tuple) + + schedule._maybe_compute_loss(stage, output, ctx.target_mbs, mb_index) + + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_next_stage_on_this_rank: + stage_index_to_stage[stage_index + 1].set_local_fwd_input(output, mb_index) + + +def stage_full_backward( + action: _Action, + ctx: _PipelineContext, +) -> None: + schedule = ctx.schedule_ref + assert isinstance(schedule, _PipelineScheduleRuntime) + stage_index_to_stage: dict[int, GraphPipelineStage] = { + stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages + } + + backward_stage_index = action.stage_index + backward_stage = stage_index_to_stage[backward_stage_index] + backward_mb_index = action.microbatch_index + assert backward_mb_index is not None + bwd_recv_ops = schedule.bwd_recv_ops + is_next_stage_on_this_rank = backward_stage.stage_index + 1 in stage_index_to_stage + is_prev_stage_on_this_rank = backward_stage.stage_index - 1 in stage_index_to_stage + + if ( + not backward_stage.is_last + # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) + and not is_next_stage_on_this_rank + ): + assert ( + backward_stage_index, + backward_mb_index, + ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input" + _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index))) + + loss = schedule._maybe_get_loss(backward_stage, backward_mb_index) + schedule.backward_counter[backward_stage_index] += 1 + last_backward = ( + schedule.backward_counter[backward_stage_index] == schedule._n_microbatches + ) + grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1 + + if not backward_stage.has_backward: + return + ( + stage_output, + saved_intermediates, + ) = backward_stage.fwd_cache.pop(backward_mb_index) + + # Compute backward + if backward_stage.is_last: + # Last stage computes gradients from loss and has no gradients from + # next stage + # TODO(sanketpurandare) + # HACK till we have loss function, we populate the tangents here manually + bwd_kwargs = { + "stage_output": loss, + "tangents": [torch.randn_like(stage_output)], + "saved_intermediates": saved_intermediates, + } + else: + # Otherwise, receive gradients from next stage + output_grads = backward_stage._retrieve_recv_grads(backward_mb_index) + # If an input to the pipeline requires gradient, + # `torch.autograd.backward` will accumulate the gradient into the + # `.grad` field of such input + bwd_kwargs = { + "stage_output": stage_output, + "tangents": output_grads, + "saved_intermediates": saved_intermediates, + } + + input_grads = _run_backward_microbatch(backward_stage, bwd_kwargs) + + backward_stage.bwd_cache[backward_mb_index] = input_grads + + # skipping detach logic + + if last_backward: + backward_stage.scale_grads(grad_scale_factor) + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special 
case] + if is_prev_stage_on_this_rank: + stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input( + backward_stage.get_local_bwd_output(backward_mb_index), + backward_mb_index, + ) + + +def stage_unshard( + action: _Action, + ctx: _PipelineContext, +) -> None: + schedule = ctx.schedule_ref + assert isinstance(schedule, _PipelineScheduleRuntime) + stage_index_to_stage: dict[int, GraphPipelineStage] = { + stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages + } + stage = stage_index_to_stage[action.stage_index] + if stage.graph_callables.unshard is None: + stage.state["unsharded_params"] = stage.state["sharded_params"] + # TODO (sanketpurandare): Add the fw_fsdp_all_gather graph call here + + +def stage_reshard( + action: _Action, + ctx: _PipelineContext, +): + schedule = ctx.schedule_ref + assert isinstance(schedule, _PipelineScheduleRuntime) + stage_index_to_stage: dict[int, GraphPipelineStage] = { + stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages + } + stage = stage_index_to_stage[action.stage_index] + stage.state["unsharded_params"].clear() + + +def stage_reduce_grad( + action: _Action, + ctx: _PipelineContext, +) -> None: + schedule = ctx.schedule_ref + assert isinstance(schedule, _PipelineScheduleRuntime) + stage_index_to_stage: dict[int, GraphPipelineStage] = { + stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages + } + stage = stage_index_to_stage[action.stage_index] + if stage.graph_callables.reduce_grad is None: + stage.state["sharded_grads"] = stage.state["unsharded_grads"] + + +class GraphPPRunner: + def __init__( + self, + schedule: _PipelineScheduleRuntime, + ): + self.schedule = schedule + + def _populate_stage_states(self, stage: GraphPipelineStage) -> None: + sharded_params = [ + v.to_local() if isinstance(v, DTensor) else v + for k, v in dict( + stage.submod.named_parameters(remove_duplicate=False) + ).items() + ] + buffers = [ + v.to_local() if isinstance(v, DTensor) else v + for k, v in dict(stage.submod.named_buffers(remove_duplicate=False)).items() + ] + stage.state["sharded_params"] = sharded_params + stage.state["buffers"] = buffers + stage.state["unsharded_grads"] = [None] * len(sharded_params) + # TODO (sanketpurandare) + # pipeline schedule runtime does not allow us to register a custom function + # for UNSHARD/RESHARD/REDUCE_GRAD action types yet + # HACK remove this once we support this + if stage.graph_callables.unshard is None: + stage.state["unsharded_params"] = stage.state["sharded_params"] + + def _accumulate_stage_grads_and_clear_states( + self, stage: GraphPipelineStage + ) -> None: + # TODO (sanketpurandare) + # We don't have a REDUCE_GRAD action yet in the ScheduleIR yet + # HACK remove this once Ivan's PR lands + if stage.graph_callables.reduce_grad is None: + stage.state["sharded_grads"] = stage.state["unsharded_grads"] + grads = stage.state["sharded_grads"] + params = list(stage.submod.parameters()) + for param, grad in zip(params, grads): + if param.requires_grad and grad is not None: + assert isinstance(grad, torch.Tensor) + if isinstance(param, DTensor): + param_spec = param._spec + _grad = DTensor.from_local( + grad, + device_mesh=param_spec.device_mesh, + placements=param_spec.placements, + shape=param_spec.shape, + stride=param_spec.stride, + ) + else: + _grad = grad # type: ignore[assignment] + if param.grad is None: + param.grad = _grad + else: + param.grad += _grad + stage.state.clear() + + def step(self, *args, **kwargs) -> None: + + for 
stage in self.schedule._stages: + assert isinstance(stage, GraphPipelineStage) + self._populate_stage_states(stage) + + self.schedule.step(*args, **kwargs) + + for stage in self.schedule._stages: + assert isinstance(stage, GraphPipelineStage) + self._accumulate_stage_grads_and_clear_states(stage) diff --git a/examples/example_ds3_local_map.py b/examples/example_ds3_local_map.py index 58acd4d..590bf2e 100644 --- a/examples/example_ds3_local_map.py +++ b/examples/example_ds3_local_map.py @@ -18,7 +18,7 @@ # must symbolically evaluate to run on 32 dp ranks # world_size = 2048 -fake_evaluate = False +fake_evaluate = True world_size = 256 diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py new file mode 100644 index 0000000..fae167b --- /dev/null +++ b/examples/example_ds3_pp.py @@ -0,0 +1,464 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from contextlib import nullcontext +from typing import Callable + +import torch +import torch.distributed._tools.fake_collectives +import torch.nn as nn +from torch._logging import trace_structured +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed.pipelining.schedules import ( + FORWARD, + FULL_BACKWARD, + PipelineScheduleMulti, + _PipelineSchedule, + _PipelineScheduleRuntime, + get_schedule_class, +) +from torch.distributed.pipelining.stage import PipelineStage +from torch.distributed.tensor.placement_types import Shard +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.testing._internal.distributed.fake_pg import FakeStore + +from autoparallel._testing.models.dsv3 import ( + DeepSeekV3Model, + DeepSeekV3ModelArgs, + DeepSeekV3Stage0, + DeepSeekV3StageI, + DeepSeekV3StageN, + MoEArgs, +) +from autoparallel.api import AutoParallelPP +from autoparallel.graph_pp_runner import ( + GraphCallables, + GraphMeta, + GraphPipelineStage, + GraphPPRunner, + stage_forward, + stage_full_backward, +) + +logger = logging.getLogger(__name__) + + +def build_pipeline_schedule( + stages: list[PipelineStage], + loss_fn: Callable, + pipeline_parallel_schedule: str, + microbatch_size: int, + local_batch_size: int, + pipeline_parallel_degree: int, +) -> _PipelineSchedule: + """Builds a pipeline schedule for the given configuration and stages.""" + schedule_class = get_schedule_class(pipeline_parallel_schedule) + + looped_schedule = issubclass(schedule_class, PipelineScheduleMulti) + assert looped_schedule, "Only looped schedules are supported" + # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training + if local_batch_size % microbatch_size != 0: + raise ValueError( + f"Batch size {local_batch_size} must be divisible by {microbatch_size=}. " + ) + n_microbatches = local_batch_size // microbatch_size + # We expect that the number of local stages (`len(stages)`) is the same across all pp ranks + num_total_stages = pipeline_parallel_degree * len(stages) + if n_microbatches < num_total_stages: + logger.warning( + f"Number of microbatches ({n_microbatches}) is less than the total number " + f"of stages ({num_total_stages}) which may result in a bubble in the pipeline." 
+ ) + + schedule = schedule_class( + stages if looped_schedule else stages[0], + n_microbatches=n_microbatches, + loss_fn=loss_fn, + ) + logger.info( + f"Using pipeline schedule {pipeline_parallel_schedule} " + f"with {n_microbatches} microbatches and {num_total_stages} stages." + ) + return schedule + + +def run_test(fake_evaluate: bool = False, use_fake_pg: bool = True): + if not use_fake_pg: + # TODO(sankepurandare): Come back to this later + torch.distributed.init_process_group() + assert "WORLD_SIZE" in os.environ, "run with torchrun --nproc-per-node 4" + world_size = int(os.getenv("WORLD_SIZE")) + pp_degree = 2 + dp_mod_ep_degree = 2 + ep_degree = 2 + dp_degree = dp_mod_ep_degree * ep_degree + assert ( + world_size == pp_degree * dp_mod_ep_degree * ep_degree + ), "world_size must be pp * dp * ep" + world_mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", + (pp_degree, dp_mod_ep_degree, ep_degree), + mesh_dim_names=( + "pp", + "dp_mod_ep", + "ep", + ), + ) + rank = int(os.getenv("RANK")) + local_rank = int(os.getenv("LOCAL_RANK")) + device = torch.device(f"cuda:{local_rank}") + pp_rank = world_mesh["pp"].get_local_rank() + else: + rank = int(os.getenv("RANK")) + pp_degree = 4 + dp_mod_ep_degree = 4 + ep_degree = 64 + dp_degree = dp_mod_ep_degree * ep_degree + world_size = pp_degree * dp_mod_ep_degree * ep_degree + + pp_rank = rank + device = torch.device(f"cuda:{pp_rank}") + + fake_store = FakeStore() + torch.distributed.init_process_group( + "fake", + store=fake_store, + rank=rank * dp_degree, # global rank is pp_rank * spmd_size + world_size=world_size, + ) + # mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) + world_mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", + (pp_degree, dp_mod_ep_degree, ep_degree), + mesh_dim_names=( + "pp", + "dp_mod_ep", + "ep", + ), + ) + + print(f"PP rank: {pp_rank}") + + stages_per_rank = 2 + logical_pp_degree = pp_degree * stages_per_rank + + # This is the spmd mesh to be used for tracing + mesh = world_mesh[("dp_mod_ep", "ep")] + + global_batch_size = 32 * dp_degree + # Batch size that will be supplied to the schedule and will be broken down into microbatches + local_batch_size = global_batch_size // dp_degree + n_microbatches = 16 + # Batch size with which the spmd graphs will actually be executed + microbatch_size = local_batch_size // n_microbatches + assert ( + microbatch_size >= 1 + ), f"invalid config {local_batch_size=}, {n_microbatches=}" + # Batch size to be used for spmd tracing + spmd_batch_size = microbatch_size * dp_degree + + seq_len = 1024 + + config = DeepSeekV3ModelArgs( + vocab_size=102400, + max_seq_len=seq_len, + dim=2048, + inter_dim=10944, + moe_inter_dim=1408, + n_layers=8, # 27, + n_dense_layers=0, # 1, + n_heads=16, + moe_args=MoEArgs( + num_experts=64, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=False, + score_before_experts=False, + mesh=mesh, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=False, + attn_mask_type="causal", + ) + + with torch.device("meta"): + model = DeepSeekV3Model(config).bfloat16() + embed, layers, norm, output = list(model.children()) + items = list(layers.items()) + assert len(items) == config.n_layers + n_layers_per_rank = len(items) // logical_pp_degree + layers = [ + nn.ModuleDict(items[i : i + n_layers_per_rank]) + for i in range(0, len(items), n_layers_per_rank) + ] + assert len(layers) == 
logical_pp_degree + for lst in layers: + assert len(lst) * len(layers) == config.n_layers + + def tracing_input_fn(): + return torch.randint( + 0, + config.vocab_size, + (spmd_batch_size, seq_len), + device=device, + ) + + def tracing_input_fn_after_first_stage(): + return torch.randn( + (spmd_batch_size, seq_len, config.dim), + device=device, + dtype=torch.bfloat16, + requires_grad=True, + ) + + def runtime_input_fn(): + return torch.randint( + 0, + config.vocab_size, + (local_batch_size, seq_len), + device=device, + ) + + def shape_inference_input_fn(): + return torch.randint( + 0, + config.vocab_size, + (microbatch_size, seq_len), + device="meta", + ) + + def shape_inference_input_fn_after_first_stage(): + return torch.randn( + (microbatch_size, seq_len, config.dim), + device="meta", + dtype=torch.bfloat16, + requires_grad=True, + ) + + def shape_inference_output_fn_last_stage(): + return torch.randn( + (microbatch_size, seq_len, config.vocab_size), + device="meta", + dtype=torch.bfloat16, + requires_grad=True, + ) + + # Step 1. Construct the logical pipeline stages + with torch.device("meta"): + stage0 = DeepSeekV3Stage0(embed, layers[0], config) + stage1 = DeepSeekV3StageI(layers[1], config) + stage2 = DeepSeekV3StageI(layers[2], config) + stage3 = DeepSeekV3StageI(layers[3], config) + stage4 = DeepSeekV3StageI(layers[4], config) + stage5 = DeepSeekV3StageI(layers[5], config) + stage6 = DeepSeekV3StageI(layers[6], config) + stage7 = DeepSeekV3StageN(layers[7], norm, output, config) + logical_stages = [ + stage0, + stage1, + stage2, + stage3, + stage4, + stage5, + stage6, + stage7, + ] + # Step 2. Assign each logical stage(s) to pp ranks + # This mapping is dependent of the number of logical pipeline stages, the pp_degree and the schedule + # For interleaved 1F1B, the mapping is: + # pp_rank_to_stage_indices = { + # 0: [0, 4], + # 1: [1, 5], + # 2: [2, 6], + # 3: [3, 7], + # } + # For DualPipeV, the mapping is: + # pp_rank_to_stage_indices = { + # 0: [0, 7], + # 1: [1, 6], + # 2: [2, 5], + # 3: [3, 4], + # } + pp_rank_to_stage_indices: dict[int, list[int]] = { + 0: [0, 4], + 1: [1, 5], + 2: [2, 6], + 3: [3, 7], + } + assert len(pp_rank_to_stage_indices) == pp_degree + for stages in pp_rank_to_stage_indices.values(): + assert len(stages) * pp_degree == len(logical_stages) + stage_indices_current_pp_rank = pp_rank_to_stage_indices[pp_rank] + stage_mods: dict[int, torch.nn.Module] = {} + stage_graphs: dict[int, GraphCallables] = {} + stage_graph_metas: dict[int, GraphMeta] = {} + # Step 3. 
Apply AutoParallel to each logical stage assigned to this pp rank + use_cache = True + root_cache = "tmp" + os.makedirs(root_cache, exist_ok=True) + from autoparallel.api import AutoParallelPPModule + + for stage_idx in stage_indices_current_pp_rank: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"begin_tracing_stage_{stage_idx}", + "encoding": "string", + }, + payload_fn=lambda: "placeholder text", + ) + stage_mod = logical_stages[stage_idx] + stage_file = os.path.join(root_cache, f"stage_{stage_idx}.pth") + if os.path.exists(stage_file) and use_cache: + cache = torch.load(stage_file, weights_only=False) + graph_callables = cache["graph_callables"] + graph_meta = cache["graph_meta"] + cache["sharded_param_dict"] = { + k: nn.Parameter(v.detach()) + for k, v in cache["sharded_param_dict"].items() + } + pp_mod = AutoParallelPPModule( + cache["sharded_param_dict"], cache["sharded_buffer_dict"], stage_mod + ) + else: + if stage_idx == 0: + input_fn = tracing_input_fn + else: + input_fn = tracing_input_fn_after_first_stage + with AutoParallelPP(stage_mod, input_fn, mesh, dynamic=True) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + # x_sharding = (Shard(0), Replicate()) + x_sharding = (Shard(0), Shard(0)) + + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + + sharding_placement = autop.optimize_placement(verbose=False) + cache = autop.apply_placement_pp(sharding_placement) + graph_callables = cache["graph_callables"] + graph_meta = cache["graph_meta"] + pp_mod = AutoParallelPPModule( + cache["sharded_param_dict"], + cache["sharded_buffer_dict"], + autop.init_weights_model, + ) + if use_cache: + torch.save(cache, stage_file) + + torch.manual_seed(pp_rank) + pp_mod.to_empty(device=device) + pp_mod.init_weights(buffer_device=device) + + # Store each stage's information in stage_mods, stage_graphs, and stage_graph_metas + stage_mods[stage_idx] = pp_mod + stage_graphs[stage_idx] = GraphCallables( + fw=graph_callables["fw"], + full_bw=graph_callables["full_bw"], + bw_dI=graph_callables["bw_dI"], + bw_dW=graph_callables["bw_dW"], + unshard=graph_callables["unshard"], + reduce_grad=graph_callables["reduce_grad"], + ) + stage_graph_metas[stage_idx] = GraphMeta( + num_mutate_inputs=graph_meta["num_mutate_inputs"], + num_user_outputs=graph_meta["num_user_outputs"], + num_symints_saved_for_bw=graph_meta["num_symints_saved_for_bw"], + num_weight_buffer_grads=graph_meta["num_weight_buffer_grads"], + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"end_tracing_stage_{stage_idx}", + "encoding": "string", + }, + payload_fn=lambda: "placeholder text", + ) + + # Two stages per pp rank + assert ( + len(stage_indices_current_pp_rank) + == len(stage_mods) + == len(stage_graphs) + == len(stage_graph_metas) + ) + + # run weight init on our sharded DTensor params + + stages = [] + # Step 4. 
Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata + for pp_stage_idx, pp_stage_mod in stage_mods.items(): + stage = GraphPipelineStage( + pp_stage_mod, + stage_graphs[pp_stage_idx], + stage_graph_metas[pp_stage_idx], + stage_index=pp_stage_idx, + num_stages=len(logical_stages), + device=device, + input_args=( + shape_inference_input_fn() + if pp_stage_idx == 0 + else shape_inference_input_fn_after_first_stage() + ), + output_args=( + shape_inference_output_fn_last_stage() + if pp_stage_idx == 7 + else shape_inference_input_fn_after_first_stage() + ), + group=world_mesh.get_group("pp"), + ) + stages.append(stage) + # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank + schedule = build_pipeline_schedule( + stages=stages, + loss_fn=None, + pipeline_parallel_schedule="Interleaved1F1B", + microbatch_size=microbatch_size, + local_batch_size=local_batch_size, + pipeline_parallel_degree=pp_degree, + ) + assert isinstance(schedule, _PipelineScheduleRuntime) + # Step 6. Override the pipeline runner's action implementations + schedule.register_custom_function(FORWARD, stage_forward) + schedule.register_custom_function(FULL_BACKWARD, stage_full_backward) + + # Step 7. Register the schedule with the graph runner + + graph_pp_runner = GraphPPRunner(schedule) + + # Step 8. Run the whole pipeline once using the graph runner + with ( + FakeTensorMode( + allow_non_fake_inputs=True, + shape_env=ShapeEnv(), + ) + if fake_evaluate + else nullcontext() + ): + with torch.no_grad(): + if pp_rank == 0: + x = runtime_input_fn() + graph_pp_runner.step(x) + else: + graph_pp_runner.step() + + print("All good!") + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.cuda.synchronize() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + run_test(fake_evaluate=True, use_fake_pg=True) diff --git a/tests/test_graph_partition.py b/examples/example_pp_graph_partition.py similarity index 58% rename from tests/test_graph_partition.py rename to examples/example_pp_graph_partition.py index 3d6f324..66edd32 100644 --- a/tests/test_graph_partition.py +++ b/examples/example_pp_graph_partition.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. 
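+
+# This example traces a DeepSeekV3 model with AutoParallelPP, applies the PP placement to
+# obtain separate forward/backward graph modules plus their metadata, splits out the FSDP
+# all-gather prefetch and reduce-scatter epilogue graphs, and then runs a single forward
+# and backward microbatch directly through _run_fw_module / _run_full_bw_module.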
+import itertools from contextlib import nullcontext import torch @@ -11,16 +12,21 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.testing._internal.distributed.fake_pg import FakeStore +from autoparallel._passes.split_fsdp_collectives import ( + split_fsdp_prefetch, + split_fsdp_reduce_scatters_epilogue, +) from autoparallel._testing.models.dsv3 import ( DeepSeekV3Model, DeepSeekV3ModelArgs, MoEArgs, ) -from autoparallel.api import AutoParallel +from autoparallel.api import AutoParallelPP +from autoparallel.graph_pp_runner import GraphMeta, _run_full_bw_module, _run_fw_module # must symbolically evaluate to run on 32 dp ranks # world_size = 2048 -fake_evaluate = False +fake_evaluate = True world_size = 256 @@ -85,7 +91,7 @@ def input_fn(): ) -with AutoParallel(model, input_fn, mesh, dynamic=True) as autop: +with AutoParallelPP(model, input_fn, mesh, dynamic=True) as autop: autop.add_parameter_memory_constraint(low=None, high=None) # x_sharding = (Shard(0), Replicate()) @@ -95,7 +101,16 @@ def input_fn(): autop.add_output_constraints([x_sharding]) sharding_placement = autop.optimize_placement() - pp_mod = autop.apply_placement_pp(sharding_placement) + res = autop.apply_placement_pp(sharding_placement) + graph_callables = res["graph_callables"] + graph_meta = res["graph_meta"] + graph_meta = GraphMeta( + num_mutate_inputs=graph_meta["num_mutate_inputs"], + num_user_outputs=graph_meta["num_user_outputs"], + num_symints_saved_for_bw=graph_meta["num_symints_saved_for_bw"], + num_weight_buffer_grads=graph_meta["num_weight_buffer_grads"], + ) + pp_mod = autop.parallel_model pp_mod.to_empty(device="cuda") # run weight init on our sharded DTensor params @@ -104,6 +119,13 @@ def input_fn(): # init_std=0.02, buffer_device="cuda" # ) # maybe not correct value pp_mod.init_weights(buffer_device="cuda") + +fw_g = graph_callables["fw"].graph +bw_g = graph_callables["full_bw"].graph + +fw_unshard_g, fw_main_g = split_fsdp_prefetch(fw_g) +bw_main_g, bw_reduce_grad_g = split_fsdp_reduce_scatters_epilogue(bw_g) + x = ( torch.randint( 0, @@ -112,22 +134,46 @@ def input_fn(): device=torch.device("cuda"), ), ) - +params_buffers = [ + v.to_local() + for k, v in + # TODO: this is very slow + itertools.chain( + dict(pp_mod.named_parameters(remove_duplicate=False)).items(), + dict(pp_mod.named_buffers(remove_duplicate=False)).items(), + ) +] # Symbolically evaluate in case you want to test running a graph bigger than your gpu -mode: nullcontext[None] | FakeTensorMode = nullcontext() -if fake_evaluate: - mode = FakeTensorMode( +with ( + FakeTensorMode( allow_non_fake_inputs=True, shape_env=ShapeEnv(), ) - -with mode: + if fake_evaluate + else nullcontext() +): # # now let's run it - outputs = pp_mod(*x) - assert len(outputs) == 1 - output = outputs[0] - output.backward(torch.randn_like(output)) + with torch.no_grad(): + fw_args = [*params_buffers, *x] + output, saved_intermediates = _run_fw_module( + graph_callables["fw"], graph_meta, fw_args + ) + tangents = [torch.randn_like(output)] + tensors_for_backward, non_tensors_for_backward = saved_intermediates + + bw_args = [ + *non_tensors_for_backward, + *tensors_for_backward, + *tangents, + ] + + input_grads, param_buffer_grads = _run_full_bw_module( + graph_callables["full_bw"], graph_meta, bw_args + ) print("All good!") + +# Cleanup: destroy process group to allow other tests to initialize their own +torch.distributed.destroy_process_group()