10 changes: 3 additions & 7 deletions autoparallel/_testing/models/dsv3.py
@@ -1062,7 +1062,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# HOPs don't support buffer mutations, keep this outside
with torch.no_grad():
self.tokens_per_expert.add_(num_tokens_per_expert)
self.tokens_per_expert.add_(num_tokens_per_expert) # type: ignore[operator]
return out

def init_weights(
@@ -1076,14 +1076,10 @@ def init_weights(
self.shared_experts.init_weights(init_std)

with torch.device(buffer_device):
self.tokens_per_expert = torch.zeros(
self.experts.num_experts, dtype=torch.float32
)
self.tokens_per_expert.zero_() # type: ignore[operator]
if self.load_balance_coeff is not None:
assert isinstance(self.expert_bias, torch.Tensor)
self.expert_bias = torch.zeros(
self.experts.num_experts, dtype=torch.float32
)
self.expert_bias.zero_() # type: ignore[operator]


def has_cuda_capability(major: int, minor: int) -> bool:
30 changes: 23 additions & 7 deletions autoparallel/graph_pp_runner.py
@@ -5,7 +5,7 @@

import logging
from dataclasses import dataclass
from typing import Any, Callable, cast, Optional, Union
from typing import Any, Callable, Optional, Union, cast

import torch
import torch.fx as fx
@@ -16,11 +16,13 @@
_wait_batch_p2p,
)
from torch.distributed.pipelining.stage import (
_normalize_model_output_as_tuple,
PipelineStage,
_normalize_model_output_as_tuple,
)
from torch.distributed.tensor import DTensor

from autoparallel.utils import DebugInterpreter

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

@@ -97,12 +99,21 @@ def scale_grads(self, grad_scale_factor: int) -> None:


def _run_fw_module(
fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any]
fw_module: fx.GraphModule,
graph_meta: GraphMeta,
fw_args: list[Any],
numerics_logs: Optional[list[str]] = None,
) -> tuple[Any, tuple[list[Any], list[Any]]]:
assert len([n for n in fw_module.graph.nodes if n.op == "placeholder"]) == len(
fw_args
), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}"
fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
if numerics_logs is not None:
debug_interpreter = DebugInterpreter(fw_module)
fw_outputs = debug_interpreter.boxed_run(fw_args)
numerics_logs += debug_interpreter.get_logs()
else:
fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)

Contributor: Maybe we should add this to all the graph module calls

Member (author): you mean the backward? I didn't add it since I couldn't test it on the base commit

Contributor: Yeah, once I land #237 we can add it for full_bw, bw_dI, bw_dW, unshard and reduce_grad.

num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs
saved_intermediates = fw_outputs[num_inner_fwd_outputs:]
num_tensors_for_backward = (
@@ -173,14 +184,16 @@ def _run_reduce_grad_module(
return sharded_grads


def _run_forward_microbatch(stage: GraphPipelineStage, *args) -> tuple[Any, Any]:
def _run_forward_microbatch(
stage: GraphPipelineStage, *args, numerics_logs: Optional[list[str]] = None
) -> tuple[Any, Any]:
fw_args = [
*stage.state["unsharded_params"],
*stage.state["buffers"],
*args,
]
user_outputs, saved_intermediates = _run_fw_module(
stage.graph_callables.fw, stage.graph_meta, fw_args
stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs
)
return (user_outputs, saved_intermediates)

@@ -220,6 +233,7 @@ def _run_backward_microbatch(
def stage_forward(
action: _Action,
ctx: _PipelineContext,
numerics_logs: Optional[list[str]] = None,
) -> None:
schedule = ctx.schedule_ref
assert isinstance(schedule, _PipelineScheduleRuntime)
@@ -267,7 +281,9 @@ def stage_forward(
"GraphPPRunner running action %s",
action,
)
output, saved_intermediates = _run_forward_microbatch(stage, *composite_args)
output, saved_intermediates = _run_forward_microbatch(
stage, *composite_args, numerics_logs=numerics_logs
)

# See [Note: pipeline model output type]
output_tuple = _normalize_model_output_as_tuple(output)
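Per the review thread above, the same optional numerics logging could eventually wrap the other graph module calls (full_bw, bw_dI, bw_dW, unshard, reduce_grad). Below is a minimal sketch of a shared helper, assuming those runners invoke their fx.GraphModule through the same boxed-argument path as _run_fw_module; the helper name is illustrative and not part of this PR:

```python
from typing import Any, Optional

import torch
import torch.fx as fx

from autoparallel.utils import DebugInterpreter


def _run_graph_module(
    gm: fx.GraphModule,
    boxed_args: list[Any],
    numerics_logs: Optional[list[str]] = None,
) -> Any:
    # Route through DebugInterpreter only when the caller asks for numerics
    # logs; the default path stays on the plain fx.Interpreter.
    if numerics_logs is not None:
        interp = DebugInterpreter(gm)
        out = interp.boxed_run(boxed_args)
        numerics_logs += interp.get_logs()
        return out
    return torch.fx.Interpreter(gm).boxed_run(boxed_args)
```

_run_fw_module could then delegate to this helper instead of branching inline, and the backward, unshard, and reduce_grad runners could opt in the same way once #237 lands.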
67 changes: 67 additions & 0 deletions autoparallel/utils.py
@@ -3,7 +3,10 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Iterable

import torch
import torch.utils._pytree as pytree
from torch.distributed._tensor.placement_types import Placement, TensorMeta
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -310,3 +313,67 @@ def _get_device_from_mesh(mesh):
return torch.device("cpu")
device_handle = _get_device_handle(mesh.device_type)
return torch.device(mesh.device_type, device_handle.current_device())


# An FX graph interpreter that logs inputs and outputs of each node
# with a few exceptions for c10d ops
class DebugInterpreter(torch.fx.Interpreter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._logs = []

def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
leaves, _ = pytree.tree_flatten(args)
for i, arg in enumerate(leaves):
if not isinstance(arg, torch.Tensor):
self._logs.append(f"{node=}, {inputs_or_outputs}[{i}]={arg}")
continue

if arg.numel() == 0:
self._logs.append(f"{node=}, {inputs_or_outputs}[{i}].numel()=0")
continue

if arg.is_complex():
real = torch.hash_tensor(arg.real)
imag = torch.hash_tensor(arg.imag)
self._logs.append(f"{node=}, {inputs_or_outputs}[{i}], {real=} {imag=}")
continue

self._logs.append(
f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
)

def run_node(self, n: torch.fx.Node) -> Any:
args, kwargs = self.fetch_args_kwargs_from_env(n)

# reading wait_tensor inputs is undefined behavior
if "wait_tensor" not in n.name:
args, _ = self.fetch_args_kwargs_from_env(n)
self.log(n.name, args, "args")

out = super().run_node(n)

# reading functional collectives outputs before wait_tensor is undefined behavior
if "c10d" not in str(n.target):
outs = out
if isinstance(outs, torch.Tensor):
outs = [outs]
self.log(n.name, outs, "outs")

return out

def get_logs(self):
return self._logs


# Always prints from rank 0 to rank N
def print_rank_by_rank(msg: Any):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.distributed.barrier()
for i in range(world_size):
if rank == i:
print(f"{rank=} start")
print(msg)
print(f"{rank=} done")
torch.distributed.barrier()
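For reference, a minimal standalone sketch of exercising DebugInterpreter outside the pipeline runner; the toy function and tensor shape are illustrative, and it assumes a PyTorch build that provides torch.hash_tensor, which DebugInterpreter relies on:

```python
import torch
import torch.fx

from autoparallel.utils import DebugInterpreter


def toy(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1.0


gm = torch.fx.symbolic_trace(toy)
interp = DebugInterpreter(gm)
# boxed_run takes ownership of the argument list, so pass a fresh one per call.
out = interp.boxed_run([torch.randn(4)])

# One log line per node input/output, each carrying a tensor hash. Under an
# initialized process group, print_rank_by_rank("\n".join(interp.get_logs()))
# would interleave these logs rank by rank.
for line in interp.get_logs():
    print(line)
```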
31 changes: 25 additions & 6 deletions examples/example_ds3_pp.py
@@ -3,10 +3,11 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import logging
import os
from contextlib import nullcontext
from typing import Callable
from typing import Callable, Optional

import torch
import torch.distributed._tools.fake_collectives
@@ -49,6 +50,7 @@
stage_reshard,
stage_unshard,
)
from autoparallel.utils import print_rank_by_rank

# Configure logging to show DEBUG messages
logging.basicConfig(
@@ -98,7 +100,7 @@ def build_pipeline_schedule(
return schedule


def run_test(fake_evaluate: bool = True):
def run_test(fake_evaluate: bool, debug_numerics: Optional[bool]):
if not fake_evaluate:
pp_degree = 2
dp_mod_ep_degree = 2
@@ -346,7 +348,9 @@ def shape_inference_output_fn_last_stage():
input_fn = tracing_input_fn
else:
input_fn = tracing_input_fn_after_first_stage
with AutoParallelPP(stage_mod, input_fn, mesh, dynamic=True) as autop:
with AutoParallelPP(
stage_mod, input_fn, mesh, dynamic=True, compile=False
) as autop:
autop.add_parameter_memory_constraint(low=None, high=None)

# x_sharding = (Shard(0), Replicate())
@@ -367,7 +371,6 @@ def shape_inference_output_fn_last_stage():
if use_cache:
torch.save(cache, stage_file)

torch.manual_seed(pp_rank)
pp_mod.to_empty(device=device)
pp_mod.init_weights(buffer_device=device)

@@ -443,7 +446,10 @@ def shape_inference_output_fn_last_stage():
)
assert isinstance(schedule, _PipelineScheduleRuntime)
# Step 6. Override the pipeline runner's action implementations
schedule.register_custom_function(FORWARD, stage_forward)
numerics_logs = []
schedule.register_custom_function(
FORWARD, functools.partial(stage_forward, numerics_logs=numerics_logs)
)
schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
schedule.register_custom_function(RESHARD, stage_reshard)
@@ -469,6 +475,9 @@
else:
graph_pp_runner.step()

if debug_numerics:
print_rank_by_rank("\n".join(numerics_logs))

print("All good!")

if torch.distributed.is_initialized():
@@ -489,6 +498,16 @@
default=False,
help="Use fake evaluation mode with FakeTensorMode (default: False)",
)
parser.add_argument(
"--rng-seed",
type=int,
default=None,
help="Use a specific rng seed and deterministic algorithms for run-to-run invariance (default: None).",
)
args = parser.parse_args()

run_test(fake_evaluate=args.fake_evaluate)
if args.rng_seed is not None:
Contributor: Let's say we have 8 ranks in total; they will all initialize their modules. Since each rank initializes a different part of the model, it is hard to compare it with a single-rank implementation for numerics debugging. We should have a solution similar to what @wconstab used in torchtitan: creating a seed checkpoint and using that for PP runs.

Member (author): By seed checkpoint, do you mean saving and loading random weights generated from a rng seed? I was thinking of just resetting the seed for weights init.

Contributor: So if pp has 8 stages, you would do init_weights for each one of them using the same seed? My concern is how would you compare the pp_runtime with spmd only for numerics?

Member (author): If we cut our stages at nn module boundaries, and init weights in the same order, we could reset the seeds at the same cuts during the spmd init weights.

For supporting arbitrary stage splits, I would need to know more about how we would implement their init_weights and checkpointing. So I put that aside for now.

Contributor: Could you also add an example that saves the params after init and the grads after accumulating grads by running microbatches in spmd? Analogously, pp also saves the params after init and the grads after running the step, and finally a script that compares both?

Member (author): Yup, I'm changing up example_ds3_local_map.py to use real tensors to be the SPMD microbatch + accumulate grad steps baseline. And I have a script to diff the outputs of DebugInterpreter that I was thinking of landing separately from this PR.

torch.use_deterministic_algorithms(True)
torch.manual_seed(args.rng_seed)

run_test(fake_evaluate=args.fake_evaluate, debug_numerics=args.rng_seed is not None)
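Following the seed discussion in the review thread above, here is a minimal sketch of the "reset the seed at each stage cut" idea, so that a block gets identical random weights whether it is initialized by the SPMD baseline or by the PP rank that owns it. The helper, the block list, and the init scheme are illustrative assumptions rather than part of this PR:

```python
import torch
import torch.nn as nn


def init_blocks_with_per_stage_seeds(
    blocks: list[nn.Module], base_seed: int, first_block_index: int = 0
) -> None:
    # Re-seed at every stage boundary (assumed to coincide with nn.Module
    # boundaries) so a block's weights depend only on its global index, not on
    # which rank initializes it.
    for k, block in enumerate(blocks, start=first_block_index):
        torch.manual_seed(base_seed + k)
        for p in block.parameters():
            if p.dim() > 1:
                nn.init.trunc_normal_(p, std=0.02)
            else:
                nn.init.zeros_(p)
```

The SPMD baseline would call this over the full block list, while a PP rank would call it only for its own blocks, passing the global index of its first block so the per-block seeds line up across runs.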