diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml
index edca044..7f2d5c5 100644
--- a/.github/workflows/test_cuda.yml
+++ b/.github/workflows/test_cuda.yml
@@ -46,4 +46,4 @@ jobs:
           python examples/example_dcp.py
           python examples/example_local_map.py
           python examples/example_ds3_local_map.py
-          python examples/example_pp_graph_partition.py
+          python examples/example_pp_graph_passes.py
diff --git a/autoparallel/_passes/graph_partition.py b/autoparallel/_passes/graph_partition.py
index 736c5ba..87f894f 100644
--- a/autoparallel/_passes/graph_partition.py
+++ b/autoparallel/_passes/graph_partition.py
@@ -29,7 +29,15 @@ def partition_joint_with_descriptors(
     fw_compiler: Callable = boxed_nop_preserve_node_meta,
     bw_compiler: Callable = boxed_nop_preserve_node_meta,
 ) -> tuple[
-    torch.fx.GraphModule, torch.fx.GraphModule, int, int, int, int, list[int], list[Any]
+    torch.fx.GraphModule,
+    torch.fx.GraphModule,
+    int,
+    int,
+    int,
+    int,
+    int,
+    list[int],
+    list[Any],
 ]:
     aot_state: AOTState = jd._aot_state
     aot_graph_capture: AOTGraphCapture = jd._aot_graph_capture
@@ -79,9 +87,11 @@ def partition_joint_with_descriptors(
     num_mutate_inputs = len(
         [x for x in fw_metadata.input_info if x.mutates_data or x.mutates_metadata]
     )
+    num_params_buffers = aot_config.num_params_buffers
     return (
         fw_module,
         bw_module,
+        num_params_buffers,
         num_user_outputs,
         num_mutate_inputs,
         num_fw_outs_saved_for_bw,
diff --git a/autoparallel/_passes/split_di_dw_graph.py b/autoparallel/_passes/split_di_dw_graph.py
index d3d873a..ed6d6b9 100644
--- a/autoparallel/_passes/split_di_dw_graph.py
+++ b/autoparallel/_passes/split_di_dw_graph.py
@@ -191,7 +191,7 @@ def _extract_fwd_bwd_modules(
 
 # TODO: in theory we can infer num_weight_gradients from the graph metadata directly
 def split_di_dw_graph(
-    bw_gm_old: fx.GraphModule, *, num_weight_gradients
+    bw_gm_old: fx.GraphModule, *, num_weight_gradients: int
 ) -> tuple[fx.GraphModule, fx.GraphModule, int]:
     # we could consider doing this is a non-mutating way
     bw_gm = copy.deepcopy(bw_gm_old)
diff --git a/autoparallel/_passes/split_fsdp_collectives.py b/autoparallel/_passes/split_fsdp_collectives.py
index a411fb3..86e6a82 100644
--- a/autoparallel/_passes/split_fsdp_collectives.py
+++ b/autoparallel/_passes/split_fsdp_collectives.py
@@ -5,6 +5,7 @@
 
 import dataclasses
 from contextlib import contextmanager
+from copy import deepcopy
 from functools import partial
 from typing import Any
 
@@ -49,12 +50,19 @@ class EpilogueInput(AOTOutput):
     pass
 
 
-def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
-    g_ins = g.find_nodes(op="placeholder")
+def split_fsdp_prefetch(
+    gm: torch.fx.GraphModule,
+    num_params: int,
+) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+    g = deepcopy(gm.graph)
+    all_g_ins = g.find_nodes(op="placeholder")
+    param_g_ins = all_g_ins[:num_params]
+    rem_g_ins = all_g_ins[num_params:]
+
     prefetch_g_outs_map = []
-    for g_in in g_ins:
-        n = g_in
+    for param_g_in in param_g_ins:
+        n = param_g_in
         last_ag = None
         while True:
             if len(n.users) != 1:
@@ -66,7 +74,7 @@ def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Gra
             if is_all_gather_into_tensor(n):
                 last_ag = n
         if last_ag is None:
-            prefetch_g_outs_map.append(g_in)
+            prefetch_g_outs_map.append(param_g_in)
         else:
             w_n = next(iter(last_ag.users))
             prefetch_g_outs_map.append(w_n)
@@ -82,7 +90,7 @@ def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Gra
     with exclude_wait_from_fx_side_effectful():
         prefetch_g = _extract_graph_with_inputs_outputs(
             g,
-            g_ins,
+            param_g_ins,
             prefetch_g_outs,
             prefetch_g_outs_descs,
             ignore_must_be_in_fw_bw=True,
@@ -90,26 +98,34 @@ def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Gra
 
         main_g = _extract_graph_with_inputs_outputs(
             g,
-            prefetch_g_outs,
+            prefetch_g_outs + rem_g_ins,
             g_outs,
             g_outs_descs,
             ignore_must_be_in_fw_bw=True,
         )
-    return prefetch_g, main_g
+    prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g)
+    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
+    return prefetch_gm, main_gm
 
 
 def split_fsdp_reduce_scatters_epilogue(
-    g: torch.fx.Graph,
-) -> tuple[torch.fx.Graph, torch.fx.Graph]:
+    gm: torch.fx.GraphModule,
+    num_grads: int,
+) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+    g = deepcopy(gm.graph)
     g_ins = g.find_nodes(op="placeholder")
     g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
-    g_outs_descs = pytree.arg_tree_leaves(
-        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
+    grad_outs = g_outs[:num_grads]
+    rem_g_outs = g_outs[num_grads:]
+    out_descs = pytree.arg_tree_leaves(
+        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(grad_outs))
     )
+    grad_outs_descs = out_descs[:num_grads]
+    rem_g_outs_descs = out_descs[num_grads:]
 
-    g_outs_map = []
-    for g_out in g_outs:
-        n = g_out
+    grad_outs_map = []
+    for grad_out in grad_outs:
+        n = grad_out
         last_rs = None
         while n is not None:
             if len(n.all_input_nodes) != 1:
@@ -124,27 +140,28 @@ def split_fsdp_reduce_scatters_epilogue(
                 # The reduction of gradients happen in multiple steps
                 last_rs = n
         if last_rs is not None:
-            g_outs_map.append(last_rs)
+            grad_outs_map.append(last_rs)
         else:
-            g_outs_map.append(g_out)
+            grad_outs_map.append(grad_out)
 
-    epi_g_ins = [n for n in g_outs_map if n is not None]
+    epi_g_ins = grad_outs_map
     epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]
 
     with exclude_wait_from_fx_side_effectful():
         main_g = _extract_graph_with_inputs_outputs(
             g,
             g_ins,
-            epi_g_ins,
-            epi_g_ins_descs,
+            epi_g_ins + rem_g_outs,
+            epi_g_ins_descs + rem_g_outs_descs,
             ignore_must_be_in_fw_bw=True,
         )
         epi_g = _extract_graph_with_inputs_outputs(
             g,
             epi_g_ins,
-            g_outs,
-            g_outs_descs,
+            grad_outs,
+            grad_outs_descs,
             ignore_must_be_in_fw_bw=True,
         )
-
-    return main_g, epi_g
+    epi_gm = torch.fx._lazy_graph_module._make_graph_module(gm, epi_g)
+    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
+    return main_gm, epi_gm
diff --git a/autoparallel/api.py b/autoparallel/api.py
index ffceae5..046b12f 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -588,14 +588,21 @@ def forward(self, *args):
 
 class AutoParallelPP(AutoParallel):
     def apply_placement_pp(
-        self, sharding_placement=None, generate_di_dw_split_graphs=False
+        self, sharding_placement=None, graph_passes: list[str] = []
     ) -> dict[str, Any]:
+        assert all(
+            g_pass in ["split_fsdp_collectives", "split_dI_dW"]
+            for g_pass in graph_passes
+        ), "Only split_fsdp_collectives and split_dI_dW are supported"
         sharded_param_dict, sharded_buffer_dict = self._apply_placement_common(
             sharding_placement
         )
+        num_params = len(sharded_param_dict)
+        num_buffers = len(sharded_buffer_dict)
         (
             fw_module,
             bw_module,
+            num_params_buffers,
             num_user_outputs,
             num_mutate_inputs,
             num_fw_outs_saved_for_bw,
@@ -603,8 +610,11 @@ def apply_placement_pp(
             _indices_of_inps_to_detach,
             adjusted_flat_args,
         ) = partition_joint_with_descriptors(self.joint_with_descriptors)
-
+        assert num_params_buffers == (
+            num_params + num_buffers
+        ), f"num_params_buffers: {num_params_buffers}, num_params: {num_params}, num_buffers: {num_buffers}"
         print(
+            f"num_params_buffers: {num_params_buffers}\n"
             f"num_user_outputs: {num_user_outputs}\n"
             f"num_mutate_inputs: {num_mutate_inputs}\n"
             f"num_fw_outs_saved_for_bw: {num_fw_outs_saved_for_bw}\n"
@@ -631,14 +641,71 @@ def apply_placement_pp(
                 print_output=False, include_stride=True, include_device=True
             ),
         )
-        if generate_di_dw_split_graphs:
-            from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
+        unshard_module: Optional[torch.fx.GraphModule] = None
+        reduce_grad_module: Optional[torch.fx.GraphModule] = None
+        if "split_fsdp_collectives" in graph_passes:
+            assert (
+                not self.reshard_after_forward
+            ), "reshard_after_forward should be False to disable FSDP all_gather in the backward pass"
+            from autoparallel._passes.split_fsdp_collectives import (
+                split_fsdp_prefetch,
+                split_fsdp_reduce_scatters_epilogue,
+            )
 
-            num_weight_gradients = (
-                self.joint_with_descriptors._aot_state.aot_config.num_params_buffers
+            unshard_module, fw_module = split_fsdp_prefetch(fw_module, num_params)
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_unshard_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: unshard_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
             )
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_fwd_no_fsdp_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: fw_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+            bw_module, reduce_grad_module = split_fsdp_reduce_scatters_epilogue(
+                bw_module, num_params
+            )
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_bwd_no_fsdp_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: bw_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_reduce_grad_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: reduce_grad_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+
+        bw_dI_module: Optional[torch.fx.GraphModule] = None
+        bw_dW_module: Optional[torch.fx.GraphModule] = None
+        num_input_grads = 0
+        if "split_dI_dW" in graph_passes:
+            from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
+
             bw_dI_module, bw_dW_module, num_input_grads = split_di_dw_graph(
-                bw_module, num_weight_gradients=num_weight_gradients
+                bw_module,
+                num_weight_gradients=num_params_buffers,
             )
             trace_structured(
                 "artifact",
@@ -669,24 +736,23 @@
                 raise RuntimeError(
                     "attempted to run split dI/dW pass on a graph that has no input gradients"
                 )
-        else:
-            bw_dI_module, bw_dW_module, num_input_grads = None, None, -1
 
         graph_meta: dict[str, int] = {
             "num_mutate_inputs": num_mutate_inputs,
             "num_user_outputs": num_user_outputs,
             "num_symints_saved_for_bw": num_symints_saved_for_bw,
-            "num_weight_buffer_grads": len(sharded_param_dict)
-            + len(sharded_buffer_dict),
+            "num_params": num_params,
+            "num_buffers": num_buffers,
             "num_input_grads": num_input_grads,
         }
+
         graph_modules: dict[str, Optional[torch.fx.GraphModule]] = {
             "fw": fw_module,
             "full_bw": bw_module,
             "bw_dI": bw_dI_module,
             "bw_dW": bw_dW_module,
-            "unshard": None,
-            "reduce_grad": None,
+            "unshard": unshard_module,
+            "reduce_grad": reduce_grad_module,
         }
         self.parallel_model = AutoParallelPPModule(
             sharded_param_dict,
diff --git a/autoparallel/graph_pp_runner.py b/autoparallel/graph_pp_runner.py
index 5741e90..6d2992c 100644
--- a/autoparallel/graph_pp_runner.py
+++ b/autoparallel/graph_pp_runner.py
@@ -36,7 +36,8 @@ class GraphMeta:
     num_mutate_inputs: int
     num_user_outputs: int
     num_symints_saved_for_bw: int
-    num_weight_buffer_grads: int
+    num_params: int
+    num_buffers: int
     num_input_grads: int
 
 
@@ -77,7 +78,7 @@ def __init__(
 
 def _run_fw_module(
     fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any]
-) -> tuple[Any, tuple[Any, Any]]:
+) -> tuple[Any, tuple[list[Any], list[Any]]]:
     assert len([n for n in fw_module.graph.nodes if n.op == "placeholder"]) == len(
         fw_args
     ), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}"
@@ -98,28 +99,58 @@ def _run_fw_module(
 
 def _run_full_bw_module(
     bw_module: fx.GraphModule, graph_meta: GraphMeta, bw_args
-) -> tuple[Any, list[Any]]:
+) -> tuple[list[Any], list[Any]]:
     assert len([n for n in bw_module.graph.nodes if n.op == "placeholder"]) == len(
         bw_args
-    ), "Mismatched number of inputs to bwd"
+    ), "Mismatched number of inputs to full bwd"
     bw_outputs = torch.fx.Interpreter(bw_module).boxed_run(bw_args)
-    param_buffer_grads = bw_outputs[: graph_meta.num_weight_buffer_grads]
-    input_grads = bw_outputs[graph_meta.num_weight_buffer_grads :]
+    num_params_buffers = graph_meta.num_params + graph_meta.num_buffers
+    param_buffer_grads = bw_outputs[:num_params_buffers]
+    input_grads = bw_outputs[num_params_buffers:]
     return input_grads, param_buffer_grads
 
 
-def _run_split_bw_module(
-    bw_dI_gm: fx.GraphModule, bw_dW_gm: fx.GraphModule, graph_meta: GraphMeta, bw_args
-) -> tuple[Any, list[Any]]:
-    assert len([n for n in bw_dI_gm.graph.nodes if n.op == "placeholder"]) == len(
-        bw_args
-    ), "Mismatched number of inputs to bwd"
-    inp_grads_and_activations = torch.fx.Interpreter(bw_dI_gm).boxed_run(bw_args)
+def _run_dI_bw_module(
+    bw_dI_module: fx.GraphModule, graph_meta: GraphMeta, bw_dI_args
+) -> tuple[list[Any], list[Any]]:
+    assert len([n for n in bw_dI_module.graph.nodes if n.op == "placeholder"]) == len(
+        bw_dI_args
+    ), "Mismatched number of inputs to dI bwd"
+    inp_grads_and_activations = torch.fx.Interpreter(bw_dI_module).boxed_run(bw_dI_args)
     inp_grads, activations = inp_grads_and_activations[
         : graph_meta.num_input_grads
     ], list(inp_grads_and_activations[graph_meta.num_input_grads :])
-    weight_grads = torch.fx.Interpreter(bw_dW_gm).boxed_run(activations)
-    return inp_grads, weight_grads
+    return inp_grads, activations
+
+
+def _run_dW_bw_module(
+    bw_dW_module: fx.GraphModule, graph_meta: GraphMeta, bw_dW_args
+) -> list[Any]:
+    assert len([n for n in bw_dW_module.graph.nodes if n.op == "placeholder"]) == len(
+        bw_dW_args
+    ), "Mismatched number of inputs to dW bwd"
+    param_buffer_grads = torch.fx.Interpreter(bw_dW_module).boxed_run(bw_dW_args)
+    return param_buffer_grads
+
+
+def _run_unshard_module(
+    unshard_module: fx.GraphModule, graph_meta: GraphMeta, unshard_args
+) -> list[Any]:
+    assert len([n for n in unshard_module.graph.nodes if n.op == "placeholder"]) == len(
+        unshard_args
+    ), "Mismatched number of inputs to unshard"
+    unsharded_params = torch.fx.Interpreter(unshard_module).boxed_run(unshard_args)
+    return unsharded_params
+
+
+def _run_reduce_grad_module(
+    reduce_grad_module: fx.GraphModule, graph_meta: GraphMeta, reduce_grad_args
+) -> list[Any]:
+    assert len(
+        [n for n in reduce_grad_module.graph.nodes if n.op == "placeholder"]
if n.op == "placeholder"] + ) == len(reduce_grad_args), "Mismatched number of inputs to reduce_grad" + sharded_grads = torch.fx.Interpreter(reduce_grad_module).boxed_run(reduce_grad_args) + return sharded_grads def _run_forward_microbatch(stage: GraphPipelineStage, *args) -> tuple[Any, Any]: @@ -146,6 +177,7 @@ def _run_backward_microbatch( *tensors_for_backward, *tangents, ] + del tensors_for_backward, non_tensors_for_backward, tangents, saved_intermediates input_grads, param_buffer_grads = _run_full_bw_module( backward_stage.graph_callables.full_bw, backward_stage.graph_meta, bw_args ) @@ -155,6 +187,7 @@ def _run_backward_microbatch( : len(backward_stage.state["sharded_params"]) ] assert len(unsharded_grads) == len(grads_to_accumulate) + assert not all(grad is None for grad in grads_to_accumulate), "All grads are None" for unsharded_grad, grad_to_accumulate in zip(unsharded_grads, grads_to_accumulate): if grad_to_accumulate is not None: if unsharded_grad is None: diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py index 156e124..02c8a86 100644 --- a/examples/example_ds3_pp.py +++ b/examples/example_ds3_pp.py @@ -373,7 +373,9 @@ def shape_inference_output_fn_last_stage(): num_mutate_inputs=graph_meta["num_mutate_inputs"], num_user_outputs=graph_meta["num_user_outputs"], num_symints_saved_for_bw=graph_meta["num_symints_saved_for_bw"], - num_weight_buffer_grads=graph_meta["num_weight_buffer_grads"], + num_params=graph_meta["num_params"], + num_buffers=graph_meta["num_buffers"], + num_input_grads=graph_meta["num_input_grads"], ) trace_structured( "artifact", diff --git a/examples/example_pp_graph_passes.py b/examples/example_pp_graph_passes.py new file mode 100644 index 0000000..05fa90b --- /dev/null +++ b/examples/example_pp_graph_passes.py @@ -0,0 +1,420 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+
+from contextlib import nullcontext
+from typing import Callable
+
+import torch
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.distributed.tensor import DeviceMesh, DTensor
+from torch.distributed.tensor.placement_types import Shard
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch.testing._internal.distributed.fake_pg import FakeStore
+
+from autoparallel._testing.models.dsv3 import (
+    DeepSeekV3Model,
+    DeepSeekV3ModelArgs,
+    MoEArgs,
+)
+from autoparallel.api import AutoParallelPP
+from autoparallel.graph_pp_runner import (
+    GraphCallables,
+    GraphMeta,
+    _run_dI_bw_module,
+    _run_dW_bw_module,
+    _run_full_bw_module,
+    _run_fw_module,
+    _run_reduce_grad_module,
+    _run_unshard_module,
+)
+
+
+def _get_pp_module_and_graphs(
+    model: torch.nn.Module,
+    mesh: DeviceMesh,
+    tracing_input_fn: Callable,
+    graph_passes: list[str] = [],
+) -> tuple[torch.nn.Module, GraphCallables, GraphMeta]:
+
+    with AutoParallelPP(
+        model, tracing_input_fn, mesh, dynamic=True, reshard_after_forward=False
+    ) as autop:
+        autop.add_parameter_memory_constraint(low=None, high=None)
+
+        # x_sharding = (Shard(0), Replicate())
+        x_sharding = (Shard(0), Shard(0))
+
+        autop.add_input_constraints([x_sharding])
+        autop.add_output_constraints([x_sharding])
+
+        sharding_placement = autop.optimize_placement()
+        res = autop.apply_placement_pp(
+            sharding_placement=sharding_placement,
+            graph_passes=graph_passes,
+        )
+    pp_mod = autop.parallel_model
+    graph_callables = res["graph_callables"]
+    graph_modules = GraphCallables(
+        fw=graph_callables["fw"],
+        full_bw=graph_callables["full_bw"],
+        bw_dI=graph_callables["bw_dI"],
+        bw_dW=graph_callables["bw_dW"],
+        unshard=graph_callables["unshard"],
+        reduce_grad=graph_callables["reduce_grad"],
+    )
+    graph_meta = res["graph_meta"]
+    graph_meta = GraphMeta(
+        num_mutate_inputs=graph_meta["num_mutate_inputs"],
+        num_user_outputs=graph_meta["num_user_outputs"],
+        num_symints_saved_for_bw=graph_meta["num_symints_saved_for_bw"],
+        num_params=graph_meta["num_params"],
+        num_buffers=graph_meta["num_buffers"],
+        num_input_grads=graph_meta["num_input_grads"],
+    )
+
+    pp_mod.to_empty(device="cuda")
+    pp_mod.init_weights(buffer_device="cuda")
+    return pp_mod, graph_modules, graph_meta
+
+
+# graph_passes=["split_dI_dW", "split_fsdp_collectives"],
+
+
+def _get_fw_inputs(
+    pp_mod: torch.nn.Module, eval_input_fn: Callable
+) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
+    x: list[torch.Tensor] = [
+        eval_input_fn(),
+    ]
+    sharded_params = [
+        v.to_local() if isinstance(v, DTensor) else v
+        for k, v in dict(pp_mod.named_parameters(remove_duplicate=False)).items()
+    ]
+    buffers = [
+        v.to_local() if isinstance(v, DTensor) else v
+        for k, v in dict(pp_mod.named_buffers(remove_duplicate=False)).items()
+    ]
+    return sharded_params, buffers, x
+
+
+# Symbolically evaluate in case you want to test running a graph bigger than your gpu
+
+
+def test_graph_partition(
+    model: torch.nn.Module,
+    mesh: DeviceMesh,
+    tracing_input_fn: Callable,
+    eval_input_fn: Callable,
+    fake_evaluate: bool = True,
+):
+
+    pp_mod, graph_modules, graph_meta = _get_pp_module_and_graphs(
+        model, mesh, tracing_input_fn
+    )
+    sharded_params, buffers, x = _get_fw_inputs(pp_mod, eval_input_fn)
+    with (
+        FakeTensorMode(
+            allow_non_fake_inputs=True,
+            shape_env=ShapeEnv(),
+        )
+        if fake_evaluate
+        else nullcontext()
+    ):
+        # # now let's run it
+        with torch.no_grad():
+            fw_args = [*sharded_params, *buffers, *x]
+            output, saved_intermediates = _run_fw_module(
+                graph_modules.fw, graph_meta, fw_args
+            )
+            tangents = [torch.randn_like(output)]
+            tensors_for_backward, non_tensors_for_backward = saved_intermediates
+
+            bw_args = [
+                *non_tensors_for_backward,
+                *tensors_for_backward,
+                *tangents,
+            ]
+            del (
+                tensors_for_backward,
+                non_tensors_for_backward,
+                tangents,
+                saved_intermediates,
+            )
+
+            input_grads, param_buffer_grads = _run_full_bw_module(
+                graph_modules.full_bw, graph_meta, bw_args
+            )
+
+    print("All good!")
+
+
+def test_split_fsdp_collectives(
+    model: torch.nn.Module,
+    mesh: DeviceMesh,
+    tracing_input_fn: Callable,
+    eval_input_fn: Callable,
+    fake_evaluate: bool = True,
+):
+
+    pp_mod, graph_modules, graph_meta = _get_pp_module_and_graphs(
+        model, mesh, tracing_input_fn, graph_passes=["split_fsdp_collectives"]
+    )
+    sharded_params, buffers, x = _get_fw_inputs(pp_mod, eval_input_fn)
+    with (
+        FakeTensorMode(
+            allow_non_fake_inputs=True,
+            shape_env=ShapeEnv(),
+        )
+        if fake_evaluate
+        else nullcontext()
+    ):
+        # # now let's run it
+        with torch.no_grad():
+            unshard_args = list(sharded_params)
+            assert graph_modules.unshard is not None
+            unsharded_params = _run_unshard_module(
+                graph_modules.unshard, graph_meta, unshard_args
+            )
+            fw_args = [*unsharded_params, *buffers, *x]
+            output, saved_intermediates = _run_fw_module(
+                graph_modules.fw, graph_meta, fw_args
+            )
+            tangents = [torch.randn_like(output)]
+            tensors_for_backward, non_tensors_for_backward = saved_intermediates
+
+            bw_args = [
+                *non_tensors_for_backward,
+                *tensors_for_backward,
+                *tangents,
+            ]
+            del (
+                tensors_for_backward,
+                non_tensors_for_backward,
+                tangents,
+                saved_intermediates,
+            )
+            input_grads, unsharded_param_buffer_grads = _run_full_bw_module(
+                graph_modules.full_bw, graph_meta, bw_args
+            )
+            unsharded_grads = list(unsharded_param_buffer_grads[: len(sharded_params)])
+            del unsharded_param_buffer_grads, input_grads
+            assert graph_modules.reduce_grad is not None
+            sharded_grads = _run_reduce_grad_module(
+                graph_modules.reduce_grad, graph_meta, unsharded_grads
+            )
+            assert len(sharded_grads) == len(sharded_params)
+
+    print("All good!")
+
+
+def test_split_dI_dW(
+    model: torch.nn.Module,
+    mesh: DeviceMesh,
+    tracing_input_fn: Callable,
+    eval_input_fn: Callable,
+    fake_evaluate: bool = True,
+):
+
+    pp_mod, graph_modules, graph_meta = _get_pp_module_and_graphs(
+        model, mesh, tracing_input_fn, graph_passes=["split_dI_dW"]
+    )
+    sharded_params, buffers, x = _get_fw_inputs(pp_mod, eval_input_fn)
+    with (
+        FakeTensorMode(
+            allow_non_fake_inputs=True,
+            shape_env=ShapeEnv(),
+        )
+        if fake_evaluate
+        else nullcontext()
+    ):
+        # # now let's run it
+        with torch.no_grad():
+            fw_args = [*sharded_params, *buffers, *x]
+            output, saved_intermediates = _run_fw_module(
+                graph_modules.fw, graph_meta, fw_args
+            )
+            tangents = [torch.randn_like(output)]
+            tensors_for_backward, non_tensors_for_backward = saved_intermediates
+
+            bw_args = [
+                *non_tensors_for_backward,
+                *tensors_for_backward,
+                *tangents,
+            ]
+            del (
+                tensors_for_backward,
+                non_tensors_for_backward,
+                tangents,
+                saved_intermediates,
+            )
+            assert graph_modules.bw_dI is not None
+            input_grads, activations_for_backward = _run_dI_bw_module(
+                graph_modules.bw_dI, graph_meta, bw_args
+            )
+            dw_args = list(activations_for_backward)
+            del activations_for_backward
+            assert graph_modules.bw_dW is not None
+            sharded_param_buffer_grads = _run_dW_bw_module(
+                graph_modules.bw_dW, graph_meta, dw_args
+            )
+            assert len(sharded_param_buffer_grads) == (
+                len(sharded_params) + len(buffers)
+            )
+
+    print("All good!")
+
+
+def test_combined(
+    model: torch.nn.Module,
+    mesh: DeviceMesh,
+    tracing_input_fn: Callable,
+    eval_input_fn: Callable,
+    fake_evaluate: bool = True,
+):
+
+    pp_mod, graph_modules, graph_meta = _get_pp_module_and_graphs(
+        model,
+        mesh,
+        tracing_input_fn,
+        graph_passes=["split_fsdp_collectives", "split_dI_dW"],
+    )
+    sharded_params, buffers, x = _get_fw_inputs(pp_mod, eval_input_fn)
+    with (
+        FakeTensorMode(
+            allow_non_fake_inputs=True,
+            shape_env=ShapeEnv(),
+        )
+        if fake_evaluate
+        else nullcontext()
+    ):
+        # # now let's run it
+        with torch.no_grad():
+            unshard_args = list(sharded_params)
+            assert graph_modules.unshard is not None
+            unsharded_params = _run_unshard_module(
+                graph_modules.unshard, graph_meta, unshard_args
+            )
+            fw_args = [*unsharded_params, *buffers, *x]
+            output, saved_intermediates = _run_fw_module(
+                graph_modules.fw, graph_meta, fw_args
+            )
+            tangents = [torch.randn_like(output)]
+            tensors_for_backward, non_tensors_for_backward = saved_intermediates
+
+            bw_args = [
+                *non_tensors_for_backward,
+                *tensors_for_backward,
+                *tangents,
+            ]
+            del (
+                tensors_for_backward,
+                non_tensors_for_backward,
+                tangents,
+                saved_intermediates,
+            )
+            assert graph_modules.bw_dI is not None
+            input_grads, activations_for_backward = _run_dI_bw_module(
+                graph_modules.bw_dI, graph_meta, bw_args
+            )
+            dw_args = list(activations_for_backward)
+            del activations_for_backward
+            assert graph_modules.bw_dW is not None
+            unsharded_param_buffer_grads = _run_dW_bw_module(
+                graph_modules.bw_dW, graph_meta, dw_args
+            )
+            unsharded_grads = list(unsharded_param_buffer_grads[: len(sharded_params)])
+            del unsharded_param_buffer_grads, input_grads
+            assert graph_modules.reduce_grad is not None
+            sharded_grads = _run_reduce_grad_module(
+                graph_modules.reduce_grad, graph_meta, unsharded_grads
+            )
+            assert len(sharded_grads) == len(sharded_params)
+
+    print("All good!")
+
+
+if __name__ == "__main__":
+    # must symbolically evaluate to run on 32 dp ranks
+    # world_size = 2048
+    fake_evaluate = True
+
+    world_size = 256
+
+    fake_store = FakeStore()
+    torch.distributed.init_process_group(
+        "fake", store=fake_store, rank=0, world_size=world_size
+    )
+    # mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
+    mesh = torch.distributed.device_mesh.init_device_mesh(
+        "cuda",
+        (world_size // 64, 64),
+        mesh_dim_names=(
+            "dp",
+            "ep",
+        ),
+    )
+
+    device = torch.device("cuda")
+
+    bs = 4 * mesh.shape[0] * mesh.shape[1]
+    seq_len = 1024
+
+    config = DeepSeekV3ModelArgs(
+        vocab_size=102400,
+        max_seq_len=seq_len,
+        dim=2048,
+        inter_dim=10944,
+        moe_inter_dim=1408,
+        n_layers=1,  # 27,
+        n_dense_layers=0,  # 1,
+        n_heads=16,
+        moe_args=MoEArgs(
+            num_experts=64,
+            num_shared_experts=2,
+            top_k=6,
+            score_func="softmax",
+            route_norm=False,
+            score_before_experts=False,
+            mesh=mesh,
+        ),
+        q_lora_rank=0,
+        kv_lora_rank=512,
+        qk_nope_head_dim=128,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        mscale=0.70,
+        use_flex_attn=False,
+        attn_mask_type="causal",
+    )
+
+    # parallelize the model
+    with torch.device("meta"):
+        model = DeepSeekV3Model(config).bfloat16()
+        model.tok_embeddings = None  # type: ignore[assignment]
+
+    def tracing_input_fn() -> torch.Tensor:
+        return torch.randn(
+            (bs, seq_len, config.dim),
+            device=device,
+            dtype=torch.bfloat16,
+            requires_grad=True,
+        )
+
+    def eval_input_fn() -> torch.Tensor:
+        return torch.randn(
+            (bs // mesh.shape[0] // mesh.shape[1], seq_len, config.dim),
+            device=device,
+            dtype=torch.bfloat16,
+            requires_grad=True,
+        )
+
+    test_graph_partition(model, mesh, tracing_input_fn, eval_input_fn, fake_evaluate)
+    test_split_fsdp_collectives(
+        model, mesh, tracing_input_fn, eval_input_fn, fake_evaluate
+    )
+    test_split_dI_dW(model, mesh, tracing_input_fn, eval_input_fn, fake_evaluate)
+    test_combined(model, mesh, tracing_input_fn, eval_input_fn, fake_evaluate)
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()