
Commit 9815a8c

Add option for run-to-run determinism and optional numerics logging
stack-info: PR: #235, branch: xmfan/stack/16
1 parent: b042f26

File tree

autoparallel/_testing/models/dsv3.py
autoparallel/graph_pp_runner.py
autoparallel/utils.py
examples/example_ds3_pp.py

4 files changed: +118 −20 lines

autoparallel/_testing/models/dsv3.py

Lines changed: 3 additions & 7 deletions
@@ -1062,7 +1062,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # HOPs don't support buffer mutations, keep this outside
         with torch.no_grad():
-            self.tokens_per_expert.add_(num_tokens_per_expert)
+            self.tokens_per_expert.add_(num_tokens_per_expert)  # type: ignore[operator]
         return out
 
     def init_weights(
@@ -1076,14 +1076,10 @@ def init_weights(
         self.shared_experts.init_weights(init_std)
 
         with torch.device(buffer_device):
-            self.tokens_per_expert = torch.zeros(
-                self.experts.num_experts, dtype=torch.float32
-            )
+            self.tokens_per_expert.zero_()  # type: ignore[operator]
             if self.load_balance_coeff is not None:
                 assert isinstance(self.expert_bias, torch.Tensor)
-                self.expert_bias = torch.zeros(
-                    self.experts.num_experts, dtype=torch.float32
-                )
+                self.expert_bias.zero_()  # type: ignore[operator]
 
 
 def has_cuda_capability(major: int, minor: int) -> bool:
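
Note on the init_weights change: zero_() resets the existing tokens_per_expert and expert_bias buffers in place instead of rebinding the attributes to freshly allocated tensors, so any reference captured elsewhere (for example the flattened buffer list a pipeline stage feeds into its graph) keeps pointing at the reset storage. A minimal standalone sketch of that identity difference (not code from this commit):

import torch

buf = torch.ones(4)
captured = buf        # e.g. a reference held in a stage's flattened buffer list

buf.zero_()           # in-place reset: `captured` observes the change
assert captured.sum().item() == 0

buf = torch.ones(4)   # rebinding: allocates a new tensor, `captured` is now stale
assert captured.sum().item() == 0 and buf.sum().item() == 4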

autoparallel/graph_pp_runner.py

Lines changed: 23 additions & 7 deletions
@@ -5,7 +5,7 @@
 
 import logging
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Optional, Union
+from typing import Any, Callable, Optional, Union, cast
 
 import torch
 import torch.fx as fx
@@ -16,11 +16,13 @@
     _wait_batch_p2p,
 )
 from torch.distributed.pipelining.stage import (
-    _normalize_model_output_as_tuple,
     PipelineStage,
+    _normalize_model_output_as_tuple,
 )
 from torch.distributed.tensor import DTensor
 
+from autoparallel.utils import DebugInterpreter
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
@@ -97,12 +99,21 @@ def scale_grads(self, grad_scale_factor: int) -> None:
 
 
 def _run_fw_module(
-    fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any]
+    fw_module: fx.GraphModule,
+    graph_meta: GraphMeta,
+    fw_args: list[Any],
+    numerics_logs: Optional[list[str]] = None,
 ) -> tuple[Any, tuple[list[Any], list[Any]]]:
     assert len([n for n in fw_module.graph.nodes if n.op == "placeholder"]) == len(
         fw_args
     ), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}"
-    fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
+    if numerics_logs is not None:
+        debug_interpreter = DebugInterpreter(fw_module)
+        fw_outputs = debug_interpreter.boxed_run(fw_args)
+        numerics_logs += debug_interpreter.get_logs()
+    else:
+        fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
+
     num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs
     saved_intermediates = fw_outputs[num_inner_fwd_outputs:]
     num_tensors_for_backward = (
@@ -173,14 +184,16 @@ def _run_reduce_grad_module(
     return sharded_grads
 
 
-def _run_forward_microbatch(stage: GraphPipelineStage, *args) -> tuple[Any, Any]:
+def _run_forward_microbatch(
+    stage: GraphPipelineStage, *args, numerics_logs: Optional[list[str]] = None
+) -> tuple[Any, Any]:
     fw_args = [
         *stage.state["unsharded_params"],
         *stage.state["buffers"],
        *args,
     ]
     user_outputs, saved_intermediates = _run_fw_module(
-        stage.graph_callables.fw, stage.graph_meta, fw_args
+        stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs
     )
     return (user_outputs, saved_intermediates)
 
@@ -220,6 +233,7 @@ def _run_backward_microbatch(
 def stage_forward(
     action: _Action,
     ctx: _PipelineContext,
+    numerics_logs: Optional[list[str]] = None,
 ) -> None:
     schedule = ctx.schedule_ref
     assert isinstance(schedule, _PipelineScheduleRuntime)
@@ -267,7 +281,9 @@
         "GraphPPRunner running action %s",
         action,
     )
-    output, saved_intermediates = _run_forward_microbatch(stage, *composite_args)
+    output, saved_intermediates = _run_forward_microbatch(
+        stage, *composite_args, numerics_logs=numerics_logs
+    )
 
     # See [Note: pipeline model output type]
     output_tuple = _normalize_model_output_as_tuple(output)
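
For context, a toy sketch of the dispatch added to _run_fw_module (run_module and TinyBlock below are hypothetical names for illustration, and it assumes a PyTorch build that provides torch.hash_tensor, which DebugInterpreter relies on): the caller owns the numerics_logs list, the debug path replays the graph under DebugInterpreter, and the collected per-node logs are appended to that list.

import torch
import torch.fx as fx

from autoparallel.utils import DebugInterpreter

# Toy stand-in for the dispatch in _run_fw_module: plain Interpreter by default,
# DebugInterpreter when the caller hands in a log list to extend.
class TinyBlock(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

def run_module(gm: fx.GraphModule, args, numerics_logs=None):
    if numerics_logs is not None:
        interp = DebugInterpreter(gm)
        out = interp.run(*args)
        numerics_logs += interp.get_logs()
        return out
    return fx.Interpreter(gm).run(*args)

gm = fx.symbolic_trace(TinyBlock())
logs: list[str] = []
run_module(gm, (torch.randn(2, 4),), numerics_logs=logs)
print("\n".join(logs))  # one line per node input/output, hashed via torch.hash_tensor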

autoparallel/utils.py

Lines changed: 67 additions & 0 deletions
@@ -3,7 +3,10 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Any, Iterable
+
 import torch
+import torch.utils._pytree as pytree
 from torch.distributed._tensor.placement_types import Placement, TensorMeta
 from torch.distributed.device_mesh import _get_device_handle
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -310,3 +313,67 @@ def _get_device_from_mesh(mesh):
         return torch.device("cpu")
     device_handle = _get_device_handle(mesh.device_type)
     return torch.device(mesh.device_type, device_handle.current_device())
+
+
+# An FX graph interpreter that logs inputs and outputs of each node
+# with a few exceptions for c10d ops
+class DebugInterpreter(torch.fx.Interpreter):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._logs = []
+
+    def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
+        leaves, _ = pytree.tree_flatten(args)
+        for i, arg in enumerate(leaves):
+            if not isinstance(arg, torch.Tensor):
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}]={arg}")
+                continue
+
+            if arg.numel() == 0:
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}].numel()=0")
+                continue
+
+            if arg.is_complex():
+                real = torch.hash_tensor(arg.real)
+                imag = torch.hash_tensor(arg.imag)
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}], {real=} {imag=}")
+                continue
+
+            self._logs.append(
+                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
+            )
+
+    def run_node(self, n: torch.fx.Node) -> Any:
+        args, kwargs = self.fetch_args_kwargs_from_env(n)
+
+        # reading wait_tensor inputs is undefined behavior
+        if "wait_tensor" not in n.name:
+            args, _ = self.fetch_args_kwargs_from_env(n)
+            self.log(n.name, args, "args")
+
+        out = super().run_node(n)
+
+        # reading functional collectives outputs before wait_tensor is undefined behavior
+        if "c10d" not in str(n.target):
+            outs = out
+            if isinstance(outs, torch.Tensor):
+                outs = [outs]
+            self.log(n.name, outs, "outs")
+
+        return out
+
+    def get_logs(self):
+        return self._logs
+
+
+# Always prints from rank 0 to rank N
+def print_rank_by_rank(msg: Any):
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    torch.distributed.barrier()
+    for i in range(world_size):
+        if rank == i:
+            print(f"{rank=} start")
+            print(msg)
+            print(f"{rank=} done")
+        torch.distributed.barrier()
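
A small end-to-end sketch of the two new helpers (single process, gloo backend, world_size=1; it assumes torch.distributed is available and a PyTorch build that provides torch.hash_tensor; not code from this commit): trace a tiny module, replay it under DebugInterpreter, then dump the collected logs rank by rank.

import os

import torch
import torch.distributed as dist

from autoparallel.utils import DebugInterpreter, print_rank_by_rank

# Minimal single-process process group so print_rank_by_rank's barriers work.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
interp = DebugInterpreter(gm)
interp.run(torch.randn(2, 4))
print_rank_by_rank("\n".join(interp.get_logs()))

dist.destroy_process_group()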

examples/example_ds3_pp.py

Lines changed: 25 additions & 6 deletions
@@ -3,10 +3,11 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
 import logging
 import os
 from contextlib import nullcontext
-from typing import Callable
+from typing import Callable, Optional
 
 import torch
 import torch.distributed._tools.fake_collectives
@@ -49,6 +50,7 @@
     stage_reshard,
     stage_unshard,
 )
+from autoparallel.utils import print_rank_by_rank
 
 # Configure logging to show DEBUG messages
 logging.basicConfig(
@@ -98,7 +100,7 @@ def build_pipeline_schedule(
     return schedule
 
 
-def run_test(fake_evaluate: bool = True):
+def run_test(fake_evaluate: bool, debug_numerics: Optional[bool]):
     if not fake_evaluate:
         pp_degree = 2
         dp_mod_ep_degree = 2
@@ -346,7 +348,9 @@ def shape_inference_output_fn_last_stage():
         input_fn = tracing_input_fn
     else:
         input_fn = tracing_input_fn_after_first_stage
-    with AutoParallelPP(stage_mod, input_fn, mesh, dynamic=True) as autop:
+    with AutoParallelPP(
+        stage_mod, input_fn, mesh, dynamic=True, compile=False
+    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)
 
         # x_sharding = (Shard(0), Replicate())
@@ -367,7 +371,6 @@ def shape_inference_output_fn_last_stage():
     if use_cache:
         torch.save(cache, stage_file)
 
-    torch.manual_seed(pp_rank)
     pp_mod.to_empty(device=device)
     pp_mod.init_weights(buffer_device=device)
 
@@ -443,7 +446,10 @@ def shape_inference_output_fn_last_stage():
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
     # Step 6. Override the pipeline runner's action implementations
-    schedule.register_custom_function(FORWARD, stage_forward)
+    numerics_logs = []
+    schedule.register_custom_function(
+        FORWARD, functools.partial(stage_forward, numerics_logs=numerics_logs)
+    )
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
     schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
     schedule.register_custom_function(RESHARD, stage_reshard)
@@ -469,6 +475,9 @@ def shape_inference_output_fn_last_stage():
     else:
         graph_pp_runner.step()
 
+    if debug_numerics:
+        print_rank_by_rank("\n".join(numerics_logs))
+
     print("All good!")
 
     if torch.distributed.is_initialized():
@@ -489,6 +498,16 @@ def shape_inference_output_fn_last_stage():
         default=False,
         help="Use fake evaluation mode with FakeTensorMode (default: False)",
     )
+    parser.add_argument(
+        "--rng-seed",
+        type=int,
+        default=None,
+        help="Use a specific rng seed and deterministic algorithms for run-to-run invariance (default: None).",
+    )
     args = parser.parse_args()
 
-    run_test(fake_evaluate=args.fake_evaluate)
+    if args.rng_seed is not None:
+        torch.use_deterministic_algorithms(True)
+        torch.manual_seed(args.rng_seed)
+
+    run_test(fake_evaluate=args.fake_evaluate, debug_numerics=args.rng_seed is not None)
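
To show what the --rng-seed flag is meant to enable (the helper below is hypothetical and not part of this commit): two runs launched with the same seed, with torch.use_deterministic_algorithms(True) set as above, should produce numerics dumps that match line for line, so a simple comparison pinpoints the first diverging node.

from typing import Optional

# Hypothetical helper: compare the rank-by-rank numerics dumps captured from two runs
# and report the first logged node whose inputs/outputs differ.
def first_divergence(dump_a: str, dump_b: str) -> Optional[str]:
    for line_a, line_b in zip(dump_a.splitlines(), dump_b.splitlines()):
        if line_a != line_b:
            return f"first mismatch:\n  {line_a}\n  {line_b}"
    return None

# Toy dumps for illustration; in practice these would come from two separate runs.
print(first_divergence("node='a', outs[0]=1\nnode='b', outs[0]=2",
                       "node='a', outs[0]=1\nnode='b', outs[0]=3"))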
