From 37019196ed168bf4984ae46af075be1e3dda1be3 Mon Sep 17 00:00:00 2001 From: sanketpurandare Date: Fri, 7 Nov 2025 13:31:39 -0800 Subject: [PATCH] Fixing backward not being called, gradient scaling and enabling backward with torch.no_grad() (#237) stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/239, branch: xmfan/stack/17 --- autoparallel/_testing/models/dsv3.py | 10 +- autoparallel/graph_pp_runner.py | 116 +++++++++++--- autoparallel/utils.py | 67 ++++++++ examples/example_ds3_pp.py | 47 +++++- examples/example_pp_graph_partition.py | 206 ------------------------- 5 files changed, 205 insertions(+), 241 deletions(-) delete mode 100644 examples/example_pp_graph_partition.py diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py index 18eb897..f9293ce 100644 --- a/autoparallel/_testing/models/dsv3.py +++ b/autoparallel/_testing/models/dsv3.py @@ -1062,7 +1062,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # HOPs don't support buffer mutations, keep this outside with torch.no_grad(): - self.tokens_per_expert.add_(num_tokens_per_expert) + self.tokens_per_expert.add_(num_tokens_per_expert) # type: ignore[operator] return out def init_weights( @@ -1076,14 +1076,10 @@ def init_weights( self.shared_experts.init_weights(init_std) with torch.device(buffer_device): - self.tokens_per_expert = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) + self.tokens_per_expert.zero_() # type: ignore[operator] if self.load_balance_coeff is not None: assert isinstance(self.expert_bias, torch.Tensor) - self.expert_bias = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) + self.expert_bias.zero_() # type: ignore[operator] def has_cuda_capability(major: int, minor: int) -> bool: diff --git a/autoparallel/graph_pp_runner.py b/autoparallel/graph_pp_runner.py index 6d2992c..1c9a748 100644 --- a/autoparallel/graph_pp_runner.py +++ b/autoparallel/graph_pp_runner.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +import logging from dataclasses import dataclass from typing import Any, Callable, Optional, Union, cast @@ -20,6 +21,11 @@ ) from torch.distributed.tensor import DTensor +from autoparallel.utils import DebugInterpreter + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + @dataclass class GraphCallables: @@ -75,14 +81,39 @@ def __init__( "unsharded_grads": [], } + def scale_grads(self, grad_scale_factor: int) -> None: + """Scale stage's gradients by `grad_scale_factor`, which should be specified in coordination with the + loss function used with pipelining. For loss functions which perform 'mean' loss reduction, `grad_scale_factor` + should be set to num_microbatches. For loss functions that use `sum` reduction, `grad_scale_factor` should + be set to 1. + + Should only be called once per pipeline schedule step, after all backwards passes have completed. + """ + + # PP scales only for its own contribution (microbatches), but relies on DP to scale further + # for DP degree. 
+ if grad_scale_factor != 1: + for grad in self.state["unsharded_grads"]: + if grad is not None: + grad.div_(grad_scale_factor) + def _run_fw_module( - fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any] + fw_module: fx.GraphModule, + graph_meta: GraphMeta, + fw_args: list[Any], + numerics_logs: Optional[list[str]] = None, ) -> tuple[Any, tuple[list[Any], list[Any]]]: assert len([n for n in fw_module.graph.nodes if n.op == "placeholder"]) == len( fw_args ), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}" - fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args) + if numerics_logs is not None: + debug_interpreter = DebugInterpreter(fw_module) + fw_outputs = debug_interpreter.boxed_run(fw_args) + numerics_logs += debug_interpreter.get_logs() + else: + fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args) + num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs saved_intermediates = fw_outputs[num_inner_fwd_outputs:] num_tensors_for_backward = ( @@ -153,14 +184,16 @@ def _run_reduce_grad_module( return sharded_grads -def _run_forward_microbatch(stage: GraphPipelineStage, *args) -> tuple[Any, Any]: +def _run_forward_microbatch( + stage: GraphPipelineStage, *args, numerics_logs: Optional[list[str]] = None +) -> tuple[Any, Any]: fw_args = [ *stage.state["unsharded_params"], *stage.state["buffers"], *args, ] user_outputs, saved_intermediates = _run_fw_module( - stage.graph_callables.fw, stage.graph_meta, fw_args + stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs ) return (user_outputs, saved_intermediates) @@ -200,6 +233,7 @@ def _run_backward_microbatch( def stage_forward( action: _Action, ctx: _PipelineContext, + numerics_logs: Optional[list[str]] = None, ) -> None: schedule = ctx.schedule_ref assert isinstance(schedule, _PipelineScheduleRuntime) @@ -243,8 +277,13 @@ def stage_forward( composite_args = stage._retrieve_recv_activations(mb_index) # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args? 
- - output, saved_intermediates = _run_forward_microbatch(stage, *composite_args) + logger.debug( + "GraphPPRunner running action %s", + action, + ) + output, saved_intermediates = _run_forward_microbatch( + stage, *composite_args, numerics_logs=numerics_logs + ) # See [Note: pipeline model output type] output_tuple = _normalize_model_output_as_tuple(output) @@ -306,6 +345,7 @@ def stage_full_backward( grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1 if not backward_stage.has_backward: + logger.debug("Returning early for backward stage") return ( stage_output, @@ -320,7 +360,7 @@ def stage_full_backward( # HACK till we have loss function, we populate the tangents here manually bwd_kwargs = { "stage_output": loss, - "tangents": [torch.randn_like(stage_output)], + "tangents": [torch.randn_like(stage_output[0])], "saved_intermediates": saved_intermediates, } else: @@ -334,10 +374,14 @@ def stage_full_backward( "tangents": output_grads, "saved_intermediates": saved_intermediates, } - + logger.debug( + "GraphPPRunner running action %s", + action, + ) input_grads = _run_backward_microbatch(backward_stage, bwd_kwargs) - - backward_stage.bwd_cache[backward_mb_index] = input_grads + backward_stage.bwd_cache[backward_mb_index] = ( + tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads + ) # skipping detach logic @@ -362,9 +406,20 @@ def stage_unshard( stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages } stage = stage_index_to_stage[action.stage_index] + logger.debug( + "GraphPPRunner running action %s", + action, + ) if stage.graph_callables.unshard is None: stage.state["unsharded_params"] = stage.state["sharded_params"] - # TODO (sanketpurandare): Add the fw_fsdp_all_gather graph call here + else: + sharded_params = list(stage.state["sharded_params"]) + unsharded_params = _run_unshard_module( + stage.graph_callables.unshard, + stage.graph_meta, + sharded_params, + ) + stage.state["unsharded_params"] = unsharded_params def stage_reshard( @@ -377,6 +432,10 @@ def stage_reshard( stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages } stage = stage_index_to_stage[action.stage_index] + logger.debug( + "GraphPPRunner running action %s", + action, + ) stage.state["unsharded_params"].clear() @@ -390,8 +449,19 @@ def stage_reduce_grad( stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages } stage = stage_index_to_stage[action.stage_index] + logger.debug( + "GraphPPRunner running action %s", + action, + ) if stage.graph_callables.reduce_grad is None: stage.state["sharded_grads"] = stage.state["unsharded_grads"] + else: + sharded_grads = _run_reduce_grad_module( + stage.graph_callables.reduce_grad, + stage.graph_meta, + stage.state["unsharded_grads"], + ) + stage.state["sharded_grads"] = sharded_grads class GraphPPRunner: @@ -400,6 +470,19 @@ def __init__( schedule: _PipelineScheduleRuntime, ): self.schedule = schedule + if not schedule._backward_requires_autograd: + assert all( + isinstance(stage, GraphPipelineStage) + and ( + stage.graph_callables.full_bw is not None + or ( + stage.graph_callables.bw_dI is not None + and stage.graph_callables.bw_dW is not None + ) + ) + for stage in schedule._stages + ) + self.schedule._has_backward = True def _populate_stage_states(self, stage: GraphPipelineStage) -> None: sharded_params = [ @@ -415,21 +498,10 @@ def _populate_stage_states(self, stage: GraphPipelineStage) -> None: stage.state["sharded_params"] = sharded_params 
stage.state["buffers"] = buffers stage.state["unsharded_grads"] = [None] * len(sharded_params) - # TODO (sanketpurandare) - # pipeline schedule runtime does not allow us to register a custom function - # for UNSHARD/RESHARD/REDUCE_GRAD action types yet - # HACK remove this once we support this - if stage.graph_callables.unshard is None: - stage.state["unsharded_params"] = stage.state["sharded_params"] def _accumulate_stage_grads_and_clear_states( self, stage: GraphPipelineStage ) -> None: - # TODO (sanketpurandare) - # We don't have a REDUCE_GRAD action yet in the ScheduleIR yet - # HACK remove this once Ivan's PR lands - if stage.graph_callables.reduce_grad is None: - stage.state["sharded_grads"] = stage.state["unsharded_grads"] grads = stage.state["sharded_grads"] params = list(stage.submod.parameters()) for param, grad in zip(params, grads): diff --git a/autoparallel/utils.py b/autoparallel/utils.py index d84a7c9..2f77739 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -3,7 +3,10 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +from typing import Any, Iterable + import torch +import torch.utils._pytree as pytree from torch.distributed._tensor.placement_types import Placement, TensorMeta from torch.distributed.device_mesh import _get_device_handle from torch.distributed.tensor._dtensor_spec import DTensorSpec @@ -310,3 +313,67 @@ def _get_device_from_mesh(mesh): return torch.device("cpu") device_handle = _get_device_handle(mesh.device_type) return torch.device(mesh.device_type, device_handle.current_device()) + + +# An FX graph interpreter that logs inputs and outputs of each node +# with a few exceptions for c10d ops +class DebugInterpreter(torch.fx.Interpreter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._logs = [] + + def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str): + leaves, _ = pytree.tree_flatten(args) + for i, arg in enumerate(leaves): + if not isinstance(arg, torch.Tensor): + self._logs.append(f"{node=}, {inputs_or_outputs}[{i}]={arg}") + continue + + if arg.numel() == 0: + self._logs.append(f"{node=}, {inputs_or_outputs}[{i}].numel()=0") + continue + + if arg.is_complex(): + real = torch.hash_tensor(arg.real) + imag = torch.hash_tensor(arg.imag) + self._logs.append(f"{node=}, {inputs_or_outputs}[{i}], {real=} {imag=}") + continue + + self._logs.append( + f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}" + ) + + def run_node(self, n: torch.fx.Node) -> Any: + args, kwargs = self.fetch_args_kwargs_from_env(n) + + # reading wait_tensor inputs is undefined behavior + if "wait_tensor" not in n.name: + args, _ = self.fetch_args_kwargs_from_env(n) + self.log(n.name, args, "args") + + out = super().run_node(n) + + # reading functional collectives outputs before wait_tensor is undefined behavior + if "c10d" not in str(n.target): + outs = out + if isinstance(outs, torch.Tensor): + outs = [outs] + self.log(n.name, outs, "outs") + + return out + + def get_logs(self): + return self._logs + + +# Always prints from rank 0 to rank N +def print_rank_by_rank(msg: Any): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + torch.distributed.barrier() + for i in range(world_size): + if rank == i: + print(f"{rank=} start") + print(msg) + print(f"{rank=} done") + torch.distributed.barrier() diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py index 02c8a86..2f148fb 100644 --- 
a/examples/example_ds3_pp.py +++ b/examples/example_ds3_pp.py @@ -3,10 +3,11 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +import functools import logging import os from contextlib import nullcontext -from typing import Callable +from typing import Callable, Optional import torch import torch.distributed._tools.fake_collectives @@ -16,6 +17,9 @@ from torch.distributed.pipelining.schedules import ( FORWARD, FULL_BACKWARD, + REDUCE_GRAD, + RESHARD, + UNSHARD, PipelineScheduleMulti, _PipelineSchedule, _PipelineScheduleRuntime, @@ -42,8 +46,16 @@ GraphPPRunner, stage_forward, stage_full_backward, + stage_reduce_grad, + stage_reshard, + stage_unshard, ) +from autoparallel.utils import print_rank_by_rank +# Configure logging to show DEBUG messages +logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) @@ -54,6 +66,7 @@ def build_pipeline_schedule( microbatch_size: int, local_batch_size: int, pipeline_parallel_degree: int, + backward_requires_autograd: bool = False, ) -> _PipelineSchedule: """Builds a pipeline schedule for the given configuration and stages.""" schedule_class = get_schedule_class(pipeline_parallel_schedule) @@ -78,6 +91,7 @@ def build_pipeline_schedule( stages if looped_schedule else stages[0], n_microbatches=n_microbatches, loss_fn=loss_fn, + backward_requires_autograd=backward_requires_autograd, ) logger.info( f"Using pipeline schedule {pipeline_parallel_schedule} " @@ -86,7 +100,7 @@ def build_pipeline_schedule( return schedule -def run_test(fake_evaluate: bool = True): +def run_test(fake_evaluate: bool, debug_numerics: Optional[bool]): if not fake_evaluate: pp_degree = 2 dp_mod_ep_degree = 2 @@ -334,7 +348,9 @@ def shape_inference_output_fn_last_stage(): input_fn = tracing_input_fn else: input_fn = tracing_input_fn_after_first_stage - with AutoParallelPP(stage_mod, input_fn, mesh, dynamic=True) as autop: + with AutoParallelPP( + stage_mod, input_fn, mesh, dynamic=True, compile=False + ) as autop: autop.add_parameter_memory_constraint(low=None, high=None) # x_sharding = (Shard(0), Replicate()) @@ -355,7 +371,6 @@ def shape_inference_output_fn_last_stage(): if use_cache: torch.save(cache, stage_file) - torch.manual_seed(pp_rank) pp_mod.to_empty(device=device) pp_mod.init_weights(buffer_device=device) @@ -427,11 +442,18 @@ def shape_inference_output_fn_last_stage(): microbatch_size=microbatch_size, local_batch_size=local_batch_size, pipeline_parallel_degree=pp_degree, + backward_requires_autograd=False, ) assert isinstance(schedule, _PipelineScheduleRuntime) # Step 6. Override the pipeline runner's action implementations - schedule.register_custom_function(FORWARD, stage_forward) + numerics_logs = [] + schedule.register_custom_function( + FORWARD, functools.partial(stage_forward, numerics_logs=numerics_logs) + ) schedule.register_custom_function(FULL_BACKWARD, stage_full_backward) + schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad) + schedule.register_custom_function(RESHARD, stage_reshard) + schedule.register_custom_function(UNSHARD, stage_unshard) # Step 7. 
Register the schedule with the graph runner @@ -453,6 +475,9 @@ def shape_inference_output_fn_last_stage(): else: graph_pp_runner.step() + if debug_numerics: + print_rank_by_rank("\n".join(numerics_logs)) + print("All good!") if torch.distributed.is_initialized(): @@ -473,6 +498,16 @@ def shape_inference_output_fn_last_stage(): default=False, help="Use fake evaluation mode with FakeTensorMode (default: False)", ) + parser.add_argument( + "--rng-seed", + type=int, + default=None, + help="Use a specific rng seed and deterministic algorithms for run-to-run invariance (default: None).", + ) args = parser.parse_args() - run_test(fake_evaluate=args.fake_evaluate) + if args.rng_seed is not None: + torch.use_deterministic_algorithms(True) + torch.manual_seed(args.rng_seed) + + run_test(fake_evaluate=args.fake_evaluate, debug_numerics=args.rng_seed is not None) diff --git a/examples/example_pp_graph_partition.py b/examples/example_pp_graph_partition.py deleted file mode 100644 index 77bd2e9..0000000 --- a/examples/example_pp_graph_partition.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import itertools -from contextlib import nullcontext - -import torch -from torch._subclasses.fake_tensor import FakeTensorMode -from torch.distributed.tensor.placement_types import Shard -from torch.fx.experimental.symbolic_shapes import ShapeEnv -from torch.testing._internal.distributed.fake_pg import FakeStore - -from autoparallel._passes.split_fsdp_collectives import ( - split_fsdp_prefetch, - split_fsdp_reduce_scatters_epilogue, -) -from autoparallel._testing.models.dsv3 import ( - DeepSeekV3Model, - DeepSeekV3ModelArgs, - MoEArgs, -) -from autoparallel.api import AutoParallelPP -from autoparallel.graph_pp_runner import GraphMeta, _run_fw_module, _run_split_bw_module - -# must symbolically evaluate to run on 32 dp ranks -# world_size = 2048 -fake_evaluate = True - -world_size = 256 - -fake_store = FakeStore() -torch.distributed.init_process_group( - "fake", store=fake_store, rank=0, world_size=world_size -) -# mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) -mesh = torch.distributed.device_mesh.init_device_mesh( - "cuda", - (world_size // 64, 64), - mesh_dim_names=( - "dp", - "ep", - ), -) - -device = torch.device("cuda") - -bs = 4 * mesh.shape[0] * mesh.shape[1] -seq_len = 1024 -dim = 2048 - -config = DeepSeekV3ModelArgs( - vocab_size=102400, - max_seq_len=seq_len, - dim=dim, - inter_dim=10944, - moe_inter_dim=1408, - n_layers=1, # 27, - n_dense_layers=0, # 1, - n_heads=16, - moe_args=MoEArgs( - num_experts=64, - num_shared_experts=2, - top_k=6, - score_func="softmax", - route_norm=False, - score_before_experts=False, - mesh=mesh, - ), - q_lora_rank=0, - kv_lora_rank=512, - qk_nope_head_dim=128, - qk_rope_head_dim=64, - v_head_dim=128, - mscale=0.70, - use_flex_attn=False, - attn_mask_type="causal", -) - -# parallelize the model -with torch.device("meta"): - model = DeepSeekV3Model(config).bfloat16() - model.tok_embeddings = None - - -# Removing tok_embeddings from the model and passing in a float input that requires_grad, -# so we can run the dI/dW partitioning pass -def input_fn(): - return torch.randn( - (bs, seq_len, dim), device=device, dtype=torch.bfloat16, requires_grad=True - ) - - -# def input_fn(): -# return torch.randint( -# 0, -# config.vocab_size, -# (bs, 
seq_len), -# device=device, -# ) - - -with AutoParallelPP(model, input_fn, mesh, dynamic=True) as autop: - autop.add_parameter_memory_constraint(low=None, high=None) - - # x_sharding = (Shard(0), Replicate()) - x_sharding = (Shard(0), Shard(0)) - - autop.add_input_constraints([x_sharding]) - autop.add_output_constraints([x_sharding]) - - sharding_placement = autop.optimize_placement() - res = autop.apply_placement_pp(sharding_placement, generate_di_dw_split_graphs=True) - graph_callables = res["graph_callables"] - graph_meta = res["graph_meta"] - graph_meta = GraphMeta( - num_mutate_inputs=graph_meta["num_mutate_inputs"], - num_user_outputs=graph_meta["num_user_outputs"], - num_symints_saved_for_bw=graph_meta["num_symints_saved_for_bw"], - num_weight_buffer_grads=graph_meta["num_weight_buffer_grads"], - num_input_grads=graph_meta["num_input_grads"], - ) - pp_mod = autop.parallel_model - -pp_mod.to_empty(device="cuda") -# run weight init on our sharded DTensor params -# TODO: plumb init_std through -# pp_mod.init_weights( -# init_std=0.02, buffer_device="cuda" -# ) # maybe not correct value -pp_mod.init_weights(buffer_device="cuda") - -fw_g = graph_callables["fw"].graph -bw_g = graph_callables["full_bw"].graph -bw_dI_g = graph_callables["bw_dI"].graph -bw_dW_g = graph_callables["bw_dW"].graph - -fw_unshard_g, fw_main_g = split_fsdp_prefetch(fw_g) -bw_main_g, bw_reduce_grad_g = split_fsdp_reduce_scatters_epilogue(bw_g) - -# x = ( -# torch.randint( -# 0, -# config.vocab_size, -# (bs // mesh.shape[0] // mesh.shape[1], seq_len), -# device=torch.device("cuda"), -# ), -# ) -x = ( - torch.randn( - (bs // mesh.shape[0] // mesh.shape[1], seq_len, dim), - device=torch.device("cuda"), - dtype=torch.bfloat16, - requires_grad=True, - ), -) - -params_buffers = [ - v.to_local() - for k, v in - # TODO: this is very slow - itertools.chain( - dict(pp_mod.named_parameters(remove_duplicate=False)).items(), - dict(pp_mod.named_buffers(remove_duplicate=False)).items(), - ) -] -# Symbolically evaluate in case you want to test running a graph bigger than your gpu - -with ( - FakeTensorMode( - allow_non_fake_inputs=True, - shape_env=ShapeEnv(), - ) - if fake_evaluate - else nullcontext() -): - # # now let's run it - with torch.no_grad(): - fw_args = [*params_buffers, *x] - output, saved_intermediates = _run_fw_module( - graph_callables["fw"], graph_meta, fw_args - ) - tangents = [torch.randn_like(output)] - tensors_for_backward, non_tensors_for_backward = saved_intermediates - - bw_args = [ - *non_tensors_for_backward, - *tensors_for_backward, - *tangents, - ] - - # Full backward - # input_grads, param_buffer_grads = _run_full_bw_module( - # graph_callables["full_bw"], graph_meta, bw_args - # ) - # Split dI/dW backward - input_grads2, param_buffer_grads2 = _run_split_bw_module( - graph_callables["bw_dI"], graph_callables["bw_dW"], graph_meta, bw_args - ) - - -print("All good!") - -# Cleanup: destroy process group to allow other tests to initialize their own -torch.distributed.destroy_process_group()
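
For reference, the sketch below condenses how the pieces introduced in this patch fit together: the custom action handlers (stage_forward, stage_full_backward, stage_reduce_grad, stage_reshard, stage_unshard) are registered on a _PipelineScheduleRuntime built with backward_requires_autograd=False, and the schedule is then handed to GraphPPRunner, mirroring examples/example_ds3_pp.py. This is a minimal sketch and not part of the patch; the helper name wire_graph_pp_runner and its signature are illustrative, and the schedule is assumed to come from build_pipeline_schedule.

import functools

from torch.distributed.pipelining.schedules import (
    FORWARD,
    FULL_BACKWARD,
    REDUCE_GRAD,
    RESHARD,
    UNSHARD,
    _PipelineScheduleRuntime,
)

from autoparallel.graph_pp_runner import (
    GraphPPRunner,
    stage_forward,
    stage_full_backward,
    stage_reduce_grad,
    stage_reshard,
    stage_unshard,
)


def wire_graph_pp_runner(
    schedule: _PipelineScheduleRuntime, debug_numerics: bool = False
) -> tuple[GraphPPRunner, list[str]]:
    # Route every schedule action through the graph-based stage implementations
    # added in autoparallel/graph_pp_runner.py.
    numerics_logs: list[str] = []
    forward_fn = (
        functools.partial(stage_forward, numerics_logs=numerics_logs)
        if debug_numerics
        else stage_forward
    )
    schedule.register_custom_function(FORWARD, forward_fn)
    schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
    schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
    schedule.register_custom_function(RESHARD, stage_reshard)
    schedule.register_custom_function(UNSHARD, stage_unshard)

    # With backward_requires_autograd=False, GraphPPRunner asserts that every
    # stage carries a compiled full_bw graph (or the bw_dI/bw_dW pair) and marks
    # the schedule as having a backward pass, so backward actions run without autograd.
    runner = GraphPPRunner(schedule)
    return runner, numerics_logs

After a training step driven by runner.step(), the collected per-node numerics logs can be dumped rank by rank with print_rank_by_rank("\n".join(numerics_logs)) when running with a fixed RNG seed, as the updated example does.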