
Commit 45287c4

Add option for run-to-run determinism and optional numerics logging
stack-info: PR: #235, branch: xmfan/stack/16
1 parent 90cd287 commit 45287c4

File tree: 4 files changed, +116 −18 lines

  autoparallel/_testing/models/dsv3.py
  autoparallel/graph_pp_runner.py
  autoparallel/utils.py
  examples/example_ds3_pp.py

autoparallel/_testing/models/dsv3.py

Lines changed: 3 additions & 7 deletions
@@ -1062,7 +1062,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # HOPs don't support buffer mutations, keep this outside
         with torch.no_grad():
-            self.tokens_per_expert.add_(num_tokens_per_expert)
+            self.tokens_per_expert.add_(num_tokens_per_expert)  # type: ignore[operator]
         return out
 
     def init_weights(
@@ -1076,14 +1076,10 @@ def init_weights(
         self.shared_experts.init_weights(init_std)
 
         with torch.device(buffer_device):
-            self.tokens_per_expert = torch.zeros(
-                self.experts.num_experts, dtype=torch.float32
-            )
+            self.tokens_per_expert.zero_()
             if self.load_balance_coeff is not None:
                 assert isinstance(self.expert_bias, torch.Tensor)
-                self.expert_bias = torch.zeros(
-                    self.experts.num_experts, dtype=torch.float32
-                )
+                self.expert_bias.zero_()
 
 
 def has_cuda_capability(major: int, minor: int) -> bool:
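
The init_weights change zeroes the existing `tokens_per_expert` and `expert_bias` buffers in place instead of rebinding them to fresh `torch.zeros(...)` tensors. A minimal toy sketch (not from this repository) of the practical difference: rebinding replaces the buffer object, so anything already holding a reference to it stops seeing resets, while `zero_()` mutates the original storage.

    import torch

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("counts", torch.zeros(4))

    m = Toy()
    ref = m.counts             # e.g. a reference captured elsewhere (a traced graph, a runner, ...)
    m.counts = torch.zeros(4)  # rebinds the attribute to a brand-new tensor
    print(m.counts is ref)     # False -- the old reference no longer tracks the buffer

    m = Toy()
    ref = m.counts
    m.counts.zero_()           # mutates the existing storage in place
    print(m.counts is ref)     # True -- every holder of the buffer sees the reset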

autoparallel/graph_pp_runner.py

Lines changed: 21 additions & 5 deletions
@@ -20,6 +20,8 @@
 )
 from torch.distributed.tensor import DTensor
 
+from autoparallel.utils import DebugInterpreter
+
 
 @dataclass
 class GraphCallables:
@@ -77,12 +79,21 @@ def __init__(
 
 
 def _run_fw_module(
-    fw_module: fx.GraphModule, graph_meta: GraphMeta, fw_args: list[Any]
+    fw_module: fx.GraphModule,
+    graph_meta: GraphMeta,
+    fw_args: list[Any],
+    numerics_logs: Optional[list[str]] = None,
 ) -> tuple[Any, tuple[list[Any], list[Any]]]:
     assert len([n for n in fw_module.graph.nodes if n.op == "placeholder"]) == len(
         fw_args
     ), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}"
-    fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
+    if numerics_logs is not None:
+        debug_interpreter = DebugInterpreter(fw_module)
+        fw_outputs = debug_interpreter.boxed_run(fw_args)
+        numerics_logs += debug_interpreter.get_logs()
+    else:
+        fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
+
     num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs
     saved_intermediates = fw_outputs[num_inner_fwd_outputs:]
     num_tensors_for_backward = (
@@ -153,14 +164,16 @@ def _run_reduce_grad_module(
     return sharded_grads
 
 
-def _run_forward_microbatch(stage: GraphPipelineStage, *args) -> tuple[Any, Any]:
+def _run_forward_microbatch(
+    stage: GraphPipelineStage, *args, numerics_logs: Optional[list[str]] = None
+) -> tuple[Any, Any]:
     fw_args = [
         *stage.state["unsharded_params"],
         *stage.state["buffers"],
         *args,
     ]
     user_outputs, saved_intermediates = _run_fw_module(
-        stage.graph_callables.fw, stage.graph_meta, fw_args
+        stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs
     )
     return (user_outputs, saved_intermediates)
 
@@ -200,6 +213,7 @@ def _run_backward_microbatch(
 def stage_forward(
     action: _Action,
     ctx: _PipelineContext,
+    numerics_logs: Optional[list[str]] = None,
 ) -> None:
     schedule = ctx.schedule_ref
     assert isinstance(schedule, _PipelineScheduleRuntime)
@@ -244,7 +258,9 @@ def stage_forward(
 
     # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?
 
-    output, saved_intermediates = _run_forward_microbatch(stage, *composite_args)
+    output, saved_intermediates = _run_forward_microbatch(
+        stage, *composite_args, numerics_logs=numerics_logs
+    )
 
     # See [Note: pipeline model output type]
     output_tuple = _normalize_model_output_as_tuple(output)
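
`numerics_logs` is threaded through as an optional accumulator: when the caller passes a list, the forward runs under `DebugInterpreter` and its per-node logs are appended to that same list; when it stays `None`, the plain `torch.fx.Interpreter` path is unchanged. A minimal sketch of the pattern with hypothetical names (`run_graph`, `Tiny`) and a stand-in log line instead of real per-node hashes, using only standard `torch.fx`:

    import functools
    from typing import Any, Optional

    import torch
    import torch.fx


    def run_graph(
        gm: torch.fx.GraphModule,
        args: list[Any],
        numerics_logs: Optional[list[str]] = None,
    ) -> Any:
        out = torch.fx.Interpreter(gm).run(*args)
        # Only pay for logging when the caller supplies a list; otherwise behave as before.
        if numerics_logs is not None:
            numerics_logs += [f"ran {len(gm.graph.nodes)} nodes"]
        return out


    class Tiny(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1


    gm = torch.fx.symbolic_trace(Tiny())
    logs: list[str] = []
    run = functools.partial(run_graph, numerics_logs=logs)  # bind the shared list once
    run(gm, [torch.randn(4)])
    run(gm, [torch.randn(4)])
    print(len(logs))  # 2 -- both calls appended to the same caller-owned list

The same `functools.partial` binding is how the example script below hands one shared list to `stage_forward` for every forward microbatch.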

autoparallel/utils.py

Lines changed: 67 additions & 0 deletions
@@ -3,7 +3,10 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Any, Iterable
+
 import torch
+import torch.utils._pytree as pytree
 from torch.distributed._tensor.placement_types import Placement, TensorMeta
 from torch.distributed.device_mesh import _get_device_handle
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -310,3 +313,67 @@ def _get_device_from_mesh(mesh):
         return torch.device("cpu")
     device_handle = _get_device_handle(mesh.device_type)
     return torch.device(mesh.device_type, device_handle.current_device())
+
+
+# An FX graph interpreter that logs inputs and outputs of each node
+# with a few exceptions for c10d ops
+class DebugInterpreter(torch.fx.Interpreter):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._logs = []
+
+    def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
+        leaves, _ = pytree.tree_flatten(args)
+        for i, arg in enumerate(leaves):
+            if not isinstance(arg, torch.Tensor):
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}]={arg}")
+                continue
+
+            if arg.numel() == 0:
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}].numel()=0")
+                continue
+
+            if arg.is_complex():
+                real = torch.hash_tensor(arg.real)
+                imag = torch.hash_tensor(arg.imag)
+                self._logs.append(f"{node=}, {inputs_or_outputs}[{i}], {real=} {imag=}")
+                continue
+
+            self._logs.append(
+                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
+            )
+
+    def run_node(self, n: torch.fx.Node) -> Any:
+        args, kwargs = self.fetch_args_kwargs_from_env(n)
+
+        # reading wait_tensor inputs is undefined behavior
+        if "wait_tensor" not in n.name:
+            args, _ = self.fetch_args_kwargs_from_env(n)
+            self.log(n.name, args, "args")
+
+        out = super().run_node(n)
+
+        # reading functional collectives outputs before wait_tensor is undefined behavior
+        if "c10d" not in str(n.target):
+            outs = out
+            if isinstance(outs, torch.Tensor):
+                outs = [outs]
+            self.log(n.name, outs, "outs")
+
+        return out
+
+    def get_logs(self):
+        return self._logs
+
+
+# Always prints from rank 0 to rank N
+def print_rank_by_rank(msg: Any):
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    torch.distributed.barrier()
+    for i in range(world_size):
+        if rank == i:
+            print(f"{rank=} start")
+            print(msg)
+            print(f"{rank=} done")
+        torch.distributed.barrier()
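
A small sketch of how `DebugInterpreter` could be exercised on its own, assuming a PyTorch build that provides `torch.hash_tensor` (which the logger relies on) and a toy traced module in place of an autoparallel graph:

    import torch
    import torch.fx

    from autoparallel.utils import DebugInterpreter


    class Tiny(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1


    gm = torch.fx.symbolic_trace(Tiny())
    interp = DebugInterpreter(gm)
    out = interp.run(torch.randn(8))

    # Each entry records one node's input or output, hashed when it is a tensor,
    # e.g. node='relu', args[0]=...
    for line in interp.get_logs():
        print(line)

`print_rank_by_rank` assumes an initialized process group, and every rank must reach it, otherwise the barriers hang.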

examples/example_ds3_pp.py

Lines changed: 25 additions & 6 deletions
@@ -3,10 +3,11 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
 import logging
 import os
 from contextlib import nullcontext
-from typing import Callable
+from typing import Callable, Optional
 
 import torch
 import torch.distributed._tools.fake_collectives
@@ -43,6 +44,7 @@
     stage_forward,
     stage_full_backward,
 )
+from autoparallel.utils import print_rank_by_rank
 
 logger = logging.getLogger(__name__)
 
@@ -86,7 +88,7 @@ def build_pipeline_schedule(
     return schedule
 
 
-def run_test(fake_evaluate: bool = True):
+def run_test(fake_evaluate: bool, debug_numerics: Optional[bool]):
     if not fake_evaluate:
         pp_degree = 2
         dp_mod_ep_degree = 2
@@ -334,7 +336,9 @@ def shape_inference_output_fn_last_stage():
         input_fn = tracing_input_fn
     else:
        input_fn = tracing_input_fn_after_first_stage
-    with AutoParallelPP(stage_mod, input_fn, mesh, dynamic=True) as autop:
+    with AutoParallelPP(
+        stage_mod, input_fn, mesh, dynamic=True, compile=False
+    ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
         # x_sharding = (Shard(0), Replicate())
@@ -355,7 +359,6 @@ def shape_inference_output_fn_last_stage():
     if use_cache:
         torch.save(cache, stage_file)
 
-    torch.manual_seed(pp_rank)
     pp_mod.to_empty(device=device)
     pp_mod.init_weights(buffer_device=device)
 
@@ -430,7 +433,10 @@ def shape_inference_output_fn_last_stage():
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
     # Step 6. Override the pipeline runner's action implementations
-    schedule.register_custom_function(FORWARD, stage_forward)
+    numerics_logs = []
+    schedule.register_custom_function(
+        FORWARD, functools.partial(stage_forward, numerics_logs=numerics_logs)
+    )
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
 
     # Step 7. Register the schedule with the graph runner
@@ -453,6 +459,9 @@ def shape_inference_output_fn_last_stage():
     else:
         graph_pp_runner.step()
 
+    if debug_numerics:
+        print_rank_by_rank("\n".join(numerics_logs))
+
     print("All good!")
 
     if torch.distributed.is_initialized():
@@ -473,6 +482,16 @@ def shape_inference_output_fn_last_stage():
         default=False,
         help="Use fake evaluation mode with FakeTensorMode (default: False)",
     )
+    parser.add_argument(
+        "--rng-seed",
+        type=int,
+        default=None,
+        help="Use a specific rng seed and deterministic algorithms for run-to-run invariance (default: None).",
+    )
     args = parser.parse_args()
 
-    run_test(fake_evaluate=args.fake_evaluate)
+    if args.rng_seed is not None:
+        torch.use_deterministic_algorithms(True)
+        torch.manual_seed(args.rng_seed)
+
+    run_test(fake_evaluate=args.fake_evaluate, debug_numerics=args.rng_seed is not None)
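
For reference, the two calls gated behind `--rng-seed` are the standard PyTorch knobs for run-to-run invariance. A standalone sketch of the same setup; the `CUBLAS_WORKSPACE_CONFIG` line is an extra requirement from the PyTorch determinism docs for deterministic cuBLAS on CUDA builds and is not part of this commit:

    import os

    import torch

    # Per the torch.use_deterministic_algorithms docs, deterministic cuBLAS on
    # CUDA >= 10.2 needs this set before the first cuBLAS call.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    torch.use_deterministic_algorithms(True)  # error out on nondeterministic kernels
    torch.manual_seed(1234)                   # same seed on every run

With this commit the example is launched with e.g. `--rng-seed 0`; since the per-rank `torch.manual_seed(pp_rank)` call was removed, all ranks share that seed when the flag is given.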
