diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py
index 18eb897..f9293ce 100644
--- a/autoparallel/_testing/models/dsv3.py
+++ b/autoparallel/_testing/models/dsv3.py
@@ -1062,7 +1062,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # HOPs don't support buffer mutations, keep this outside
         with torch.no_grad():
-            self.tokens_per_expert.add_(num_tokens_per_expert)
+            self.tokens_per_expert.add_(num_tokens_per_expert)  # type: ignore[operator]
         return out
 
     def init_weights(
@@ -1076,14 +1076,10 @@ def init_weights(
         self.shared_experts.init_weights(init_std)
 
         with torch.device(buffer_device):
-            self.tokens_per_expert = torch.zeros(
-                self.experts.num_experts, dtype=torch.float32
-            )
+            self.tokens_per_expert.zero_()  # type: ignore[operator]
             if self.load_balance_coeff is not None:
                 assert isinstance(self.expert_bias, torch.Tensor)
-                self.expert_bias = torch.zeros(
-                    self.experts.num_experts, dtype=torch.float32
-                )
+                self.expert_bias.zero_()  # type: ignore[operator]
 
 
 def has_cuda_capability(major: int, minor: int) -> bool:
diff --git a/autoparallel/graph_pp_runner.py b/autoparallel/graph_pp_runner.py
index d082032..1c9a748 100644
--- a/autoparallel/graph_pp_runner.py
+++ b/autoparallel/graph_pp_runner.py
@@ -5,7 +5,7 @@
 
 import logging
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Optional, Union
+from typing import Any, Callable, Optional, Union, cast
 
 import torch
 import torch.fx as fx
@@ -16,11 +16,13 @@
     _wait_batch_p2p,
 )
 from torch.distributed.pipelining.stage import (
-    _normalize_model_output_as_tuple,
     PipelineStage,
+    _normalize_model_output_as_tuple,
 )
 from torch.distributed.tensor import DTensor
 
+from autoparallel.utils import DebugInterpreter
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
@@ -97,12 +99,21 @@ def scale_grads(self, grad_scale_factor: int) -> None:
 
 
 def _run_fw_module(
-    fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any]
+    fw_module: fx.GraphModule,
+    graph_meta: GraphMeta,
+    fw_args: list[Any],
+    numerics_logs: Optional[list[str]] = None,
 ) -> tuple[Any, tuple[list[Any], list[Any]]]:
     assert len([n for n in fw_module.graph.nodes if n.op == "placeholder"]) == len(
         fw_args
     ), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}"
-    fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
+    if numerics_logs is not None:
+        debug_interpreter = DebugInterpreter(fw_module)
+        fw_outputs = debug_interpreter.boxed_run(fw_args)
+        numerics_logs += debug_interpreter.get_logs()
+    else:
+        fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
+
     num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs
     saved_intermediates = fw_outputs[num_inner_fwd_outputs:]
     num_tensors_for_backward = (
@@ -173,14 +184,16 @@ def _run_reduce_grad_module(
     return sharded_grads
 
 
-def _run_forward_microbatch(stage: GraphPipelineStage, *args) -> tuple[Any, Any]:
+def _run_forward_microbatch(
+    stage: GraphPipelineStage, *args, numerics_logs: Optional[list[str]] = None
+) -> tuple[Any, Any]:
     fw_args = [
         *stage.state["unsharded_params"],
         *stage.state["buffers"],
        *args,
    ]
    user_outputs, saved_intermediates = _run_fw_module(
-        stage.graph_callables.fw, stage.graph_meta, fw_args
+        stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs
     )
     return (user_outputs, saved_intermediates)
 
@@ -220,6 +233,7 @@ def _run_backward_microbatch(
 def stage_forward(
     action: _Action,
     ctx: _PipelineContext,
+    numerics_logs: Optional[list[str]] = None,
 ) -> None:
     schedule = ctx.schedule_ref
     assert isinstance(schedule, _PipelineScheduleRuntime)
@@ -267,7 +281,9 @@ def stage_forward(
         "GraphPPRunner running action %s",
         action,
     )
-    output, saved_intermediates = _run_forward_microbatch(stage, *composite_args)
+    output, saved_intermediates = _run_forward_microbatch(
+        stage, *composite_args, numerics_logs=numerics_logs
+    )
 
     # See [Note: pipeline model output type]
     output_tuple = _normalize_model_output_as_tuple(output)
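Reviewer note, not part of the patch: `_run_fw_module` above swaps `torch.fx.Interpreter` for `DebugInterpreter` (added to `autoparallel/utils.py` just below) whenever a log list is supplied. For readers unfamiliar with the hook being used, here is a minimal, self-contained sketch of the same interpreter-subclass pattern; `LoggingInterpreter` and `scale_and_sum` are illustrative names invented for this sketch:

```python
import torch
import torch.fx


class LoggingInterpreter(torch.fx.Interpreter):
    """Runs a GraphModule node by node, recording one log line per tensor produced."""

    def __init__(self, module: torch.fx.GraphModule):
        super().__init__(module)
        self.logs: list[str] = []

    def run_node(self, n: torch.fx.Node):
        out = super().run_node(n)
        if isinstance(out, torch.Tensor):
            # Cheap stand-in for torch.hash_tensor, which needs a recent PyTorch.
            self.logs.append(
                f"{n.name}: shape={tuple(out.shape)} sum={out.sum().item():.6f}"
            )
        return out


def scale_and_sum(x):
    return (torch.relu(x) * 2).sum()


gm = torch.fx.symbolic_trace(scale_and_sum)
interp = LoggingInterpreter(gm)
interp.run(torch.randn(8))
print("\n".join(interp.logs))
```

`DebugInterpreter` follows the same shape, but hashes values instead of summing them and skips values whose contents are undefined around collectives.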
diff --git a/autoparallel/utils.py b/autoparallel/utils.py
index d84a7c9..2f77739 100644
--- a/autoparallel/utils.py
+++ b/autoparallel/utils.py
@@ -3,7 +3,10 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Any, Iterable
+
 import torch
+import torch.utils._pytree as pytree
 from torch.distributed._tensor.placement_types import Placement, TensorMeta
 from torch.distributed.device_mesh import _get_device_handle
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -310,3 +313,65 @@ def _get_device_from_mesh(mesh):
         return torch.device("cpu")
     device_handle = _get_device_handle(mesh.device_type)
     return torch.device(mesh.device_type, device_handle.current_device())
+
+
+# An FX graph interpreter that logs a hash of the inputs and outputs of each
+# node, with a few exceptions for c10d ops
+class DebugInterpreter(torch.fx.Interpreter):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._logs: list[str] = []
+
+    def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
+        leaves, _ = pytree.tree_flatten(args)
+        for i, arg in enumerate(leaves):
+            if not isinstance(arg, torch.Tensor):
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}]={arg}")
+                continue
+
+            if arg.numel() == 0:
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}].numel()=0")
+                continue
+
+            if arg.is_complex():
+                real = torch.hash_tensor(arg.real)
+                imag = torch.hash_tensor(arg.imag)
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}], {real=} {imag=}")
+                continue
+
+            self._logs.append(
+                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
+            )
+
+    def run_node(self, n: torch.fx.Node) -> Any:
+        # reading wait_tensor inputs is undefined behavior
+        if "wait_tensor" not in n.name:
+            args, _ = self.fetch_args_kwargs_from_env(n)
+            self.log(n.name, args, "args")
+
+        out = super().run_node(n)
+
+        # reading functional collectives outputs before wait_tensor is undefined behavior
+        if "c10d" not in str(n.target):
+            outs = out
+            if isinstance(outs, torch.Tensor):
+                outs = [outs]
+            self.log(n.name, outs, "outs")
+
+        return out
+
+    def get_logs(self) -> list[str]:
+        return self._logs
+
+
+# Prints msg from each rank in turn, from rank 0 to rank world_size - 1
+def print_rank_by_rank(msg: Any):
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    torch.distributed.barrier()
+    for i in range(world_size):
+        if rank == i:
+            print(f"{rank=} start")
+            print(msg)
+            print(f"{rank=} done")
+        torch.distributed.barrier()
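A quick way to exercise `print_rank_by_rank` on its own (the launcher invocation and the `gloo` backend are assumptions for this sketch, not part of the patch):

```python
# Save as demo_print.py and run: torchrun --nproc-per-node 2 demo_print.py
import torch.distributed as dist

from autoparallel.utils import print_rank_by_rank

dist.init_process_group(backend="gloo")  # gloo keeps the demo CPU-only
print_rank_by_rank(f"hello from rank {dist.get_rank()}")
dist.destroy_process_group()
```

The barrier inside the loop is what serializes the output: every rank waits at each iteration, and only the rank whose turn it is prints.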
diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py
index 556e583..2f148fb 100644
--- a/examples/example_ds3_pp.py
+++ b/examples/example_ds3_pp.py
@@ -3,10 +3,11 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
 import logging
 import os
 from contextlib import nullcontext
-from typing import Callable
+from typing import Callable, Optional
 
 import torch
 import torch.distributed._tools.fake_collectives
@@ -49,6 +50,7 @@
     stage_reshard,
     stage_unshard,
 )
+from autoparallel.utils import print_rank_by_rank
 
 # Configure logging to show DEBUG messages
 logging.basicConfig(
@@ -98,7 +100,7 @@ def build_pipeline_schedule(
     return schedule
 
 
-def run_test(fake_evaluate: bool = True):
+def run_test(fake_evaluate: bool, debug_numerics: Optional[bool]):
     if not fake_evaluate:
         pp_degree = 2
         dp_mod_ep_degree = 2
@@ -346,7 +348,9 @@ def shape_inference_output_fn_last_stage():
         input_fn = tracing_input_fn
     else:
         input_fn = tracing_input_fn_after_first_stage
-    with AutoParallelPP(stage_mod, input_fn, mesh, dynamic=True) as autop:
+    with AutoParallelPP(
+        stage_mod, input_fn, mesh, dynamic=True, compile=False
+    ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
         # x_sharding = (Shard(0), Replicate())
@@ -367,7 +371,6 @@ def shape_inference_output_fn_last_stage():
 
     if use_cache:
         torch.save(cache, stage_file)
 
-    torch.manual_seed(pp_rank)
     pp_mod.to_empty(device=device)
     pp_mod.init_weights(buffer_device=device)
@@ -443,7 +446,10 @@
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
     # Step 6. Override the pipeline runner's action implementations
-    schedule.register_custom_function(FORWARD, stage_forward)
+    numerics_logs = [] if debug_numerics else None
+    schedule.register_custom_function(
+        FORWARD, functools.partial(stage_forward, numerics_logs=numerics_logs)
+    )
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
     schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
     schedule.register_custom_function(RESHARD, stage_reshard)
@@ -469,6 +475,9 @@
     else:
         graph_pp_runner.step()
 
+    if debug_numerics:
+        print_rank_by_rank("\n".join(numerics_logs))
+
     print("All good!")
 
     if torch.distributed.is_initialized():
@@ -489,6 +498,16 @@
         default=False,
         help="Use fake evaluation mode with FakeTensorMode (default: False)",
     )
+    parser.add_argument(
+        "--rng-seed",
+        type=int,
+        default=None,
+        help="Use a specific RNG seed and deterministic algorithms for run-to-run invariance (default: None).",
+    )
     args = parser.parse_args()
 
-    run_test(fake_evaluate=args.fake_evaluate)
+    if args.rng_seed is not None:
+        torch.use_deterministic_algorithms(True)
+        torch.manual_seed(args.rng_seed)
+
+    run_test(fake_evaluate=args.fake_evaluate, debug_numerics=args.rng_seed is not None)
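End to end, the workflow this patch appears to enable is: run the example twice with the same seed, for example `CUBLAS_WORKSPACE_CONFIG=:4096:8 torchrun --nproc-per-node 8 examples/example_ds3_pp.py --rng-seed 0` (the env var is needed because `torch.use_deterministic_algorithms(True)` otherwise raises on nondeterministic cuBLAS kernels; the launcher settings here are assumptions), capture stdout from each run, and compare the per-node hash logs. A hypothetical helper for that comparison; the function and file names are illustrative, not part of the patch:

```python
# Compare two captured numerics logs line by line; the first mismatch names
# the first graph node whose values diverged between runs.
def logs_match(path_a: str, path_b: str) -> bool:
    with open(path_a) as fa, open(path_b) as fb:
        for line_a, line_b in zip(fa, fb):  # zip stops at the shorter file
            if line_a != line_b:
                print(f"first divergence:\n  run A: {line_a.rstrip()}\n  run B: {line_b.rstrip()}")
                return False
    return True


if __name__ == "__main__":
    print(logs_match("run_a.log", "run_b.log"))
```

Because every log line carries the node name plus a content hash, the first mismatching line pinpoints the first op whose numerics diverged, rather than only revealing a difference in the final loss.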