Add option for run-to-run determinism and optional numerics logging #235
@@ -3,10 +3,11 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+import functools
 import logging
 import os
 from contextlib import nullcontext
-from typing import Callable
+from typing import Callable, Optional

 import torch
 import torch.distributed._tools.fake_collectives
@@ -49,6 +50,7 @@
     stage_reshard,
     stage_unshard,
 )
+from autoparallel.utils import print_rank_by_rank

 # Configure logging to show DEBUG messages
 logging.basicConfig(
@@ -98,7 +100,7 @@ def build_pipeline_schedule(
     return schedule


-def run_test(fake_evaluate: bool = True):
+def run_test(fake_evaluate: bool, debug_numerics: Optional[bool]):
     if not fake_evaluate:
         pp_degree = 2
         dp_mod_ep_degree = 2
@@ -346,7 +348,9 @@ def shape_inference_output_fn_last_stage():
         input_fn = tracing_input_fn
     else:
         input_fn = tracing_input_fn_after_first_stage
-    with AutoParallelPP(stage_mod, input_fn, mesh, dynamic=True) as autop:
+    with AutoParallelPP(
+        stage_mod, input_fn, mesh, dynamic=True, compile=False
+    ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)

         # x_sharding = (Shard(0), Replicate())
@@ -367,7 +371,6 @@ def shape_inference_output_fn_last_stage():
         if use_cache:
             torch.save(cache, stage_file)

-
     torch.manual_seed(pp_rank)
     pp_mod.to_empty(device=device)
     pp_mod.init_weights(buffer_device=device)
@@ -443,7 +446,10 @@ def shape_inference_output_fn_last_stage():
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
     # Step 6. Override the pipeline runner's action implementations
-    schedule.register_custom_function(FORWARD, stage_forward)
+    numerics_logs = []
+    schedule.register_custom_function(
+        FORWARD, functools.partial(stage_forward, numerics_logs=numerics_logs)
+    )
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
     schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
     schedule.register_custom_function(RESHARD, stage_reshard)
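The `numerics_logs` list is shared between `run_test` and the registered forward through `functools.partial`, so whatever `stage_forward` appends while the schedule runs is available afterwards for printing. A self-contained toy showing the same binding pattern (the `ToyRunner` and `toy_forward` names are made up for this sketch, not part of autoparallel):

```python
import functools

class ToyRunner:
    """Stand-in for the pipeline schedule; not an autoparallel API."""

    def __init__(self):
        self._forward = None

    def register_custom_function(self, fn):
        self._forward = fn

    def step(self):
        self._forward(output_norm=1.2345)

def toy_forward(output_norm, numerics_logs):
    # The partial pre-binds the shared list, so appends made here are
    # visible to the caller after the step finishes.
    numerics_logs.append(f"forward output norm: {output_norm:.4f}")

numerics_logs = []
runner = ToyRunner()
runner.register_custom_function(
    functools.partial(toy_forward, numerics_logs=numerics_logs)
)
runner.step()
print("\n".join(numerics_logs))  # -> forward output norm: 1.2345
```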
@@ -469,6 +475,9 @@ def shape_inference_output_fn_last_stage():
     else:
         graph_pp_runner.step()

+    if debug_numerics:
+        print_rank_by_rank("\n".join(numerics_logs))
+
     print("All good!")

     if torch.distributed.is_initialized():
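`print_rank_by_rank` comes from `autoparallel.utils` (imported in an earlier hunk); its implementation isn't shown in this diff, but helpers like it typically serialize printing across ranks so the per-rank numerics logs don't interleave. A rough, purely illustrative sketch of that pattern:

```python
import torch.distributed as dist

def print_rank_by_rank_sketch(msg: str) -> None:
    # Take turns printing, one rank at a time, separated by barriers,
    # so multi-rank numerics dumps stay readable.
    if not dist.is_initialized():
        print(msg)
        return
    for rank in range(dist.get_world_size()):
        if dist.get_rank() == rank:
            print(f"----- rank {rank} -----\n{msg}", flush=True)
        dist.barrier()
```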
@@ -489,6 +498,16 @@ def shape_inference_output_fn_last_stage():
         default=False,
         help="Use fake evaluation mode with FakeTensorMode (default: False)",
     )
+    parser.add_argument(
+        "--rng-seed",
+        type=int,
+        default=None,
+        help="Use a specific rng seed and deterministic algorithms for run-to-run invariance (default: None).",
+    )
     args = parser.parse_args()

-    run_test(fake_evaluate=args.fake_evaluate)
+    if args.rng_seed is not None:
Contributor: Let's say we have 8 ranks in total; they will all initialize their modules. Since each rank initializes a different part of the model, it is hard to compare it with a single-rank implementation for numerics debugging. We should have a solution similar to what @wconstab used in torchtitan: creating a seed checkpoint and using that for PP runs.
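One simple way to realize such a seed checkpoint (a hedged sketch, not necessarily how torchtitan does it): initialize the full model once under a fixed seed, save the state dict, and have every PP rank load just its stage's slice. The file name and the `init_weights()` hook below are assumptions for illustration.

```python
import torch

def save_seed_checkpoint(model, path="seed_ckpt.pt", seed=0):
    # Initialize the full model once under a fixed seed and persist the weights,
    # so every PP run (and an SPMD reference run) starts from identical values.
    torch.manual_seed(seed)
    model.init_weights()  # assumed init hook, mirroring the example script above
    torch.save(model.state_dict(), path)

def load_seed_checkpoint(stage_mod, path="seed_ckpt.pt"):
    # Each PP rank keeps only the entries that belong to its stage submodule.
    full_sd = torch.load(path, map_location="cpu")
    stage_keys = set(stage_mod.state_dict().keys())
    stage_mod.load_state_dict(
        {k: v for k, v in full_sd.items() if k in stage_keys}, strict=False
    )
```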
Author: By seed checkpoint, do you mean saving and loading random weights generated from an rng seed? I was thinking of just resetting the seed for weight init.
Contributor: So if PP has 8 stages, you would do init_weights for each one of them using the same seed? My concern is how you would compare the pp_runtime with SPMD-only for numerics.
Author: If we cut our stages at nn.Module boundaries and init weights in the same order, we could reset the seeds at the same cuts during the SPMD init_weights. For supporting arbitrary stage splits, I would need to know more about how we would implement their init_weights and checkpointing, so I put that aside for now.
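The seed-resetting idea could look roughly like this, assuming stages are cut at nn.Module boundaries and both the PP ranks and the SPMD reference iterate the same submodules in the same order (the `submodules` list and the per-module `init_weights()` are assumptions for illustration):

```python
import torch

def init_weights_with_stage_seeds(submodules, base_seed=0):
    # Reset the RNG at each stage cut so the k-th stage is initialized from the
    # same seed whether it is built by a PP rank or by the SPMD run that
    # initializes the whole model.
    for stage_idx, submodule in enumerate(submodules):
        torch.manual_seed(base_seed + stage_idx)
        submodule.init_weights()
```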
Contributor: Could you also add an example that saves the params after init and the grads after accumulating grads by running microbatches in SPMD? Analogously, PP also saves the params after init and the grads after running the step, and finally a script that compares both?
Author: Yup, I'm changing up example_ds3_local_map.py to use real tensors so it can be the SPMD microbatch + accumulate-grads baseline. And I have a script to diff the outputs of DebugInterpreter that I was thinking of landing separately from this PR.
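The dump-and-compare flow being discussed could be as simple as the sketch below; the file layout and helper names are hypothetical, and it assumes both runs save plain CPU tensors for params (after init) and grads (after the step).

```python
import torch

def dump_params_and_grads(model, path):
    # Save params right after init and grads right after the (micro)step.
    torch.save(
        {
            "params": {k: v.detach().cpu() for k, v in model.named_parameters()},
            "grads": {
                k: v.grad.cpu() for k, v in model.named_parameters() if v.grad is not None
            },
        },
        path,
    )

def compare_dumps(spmd_path, pp_path, rtol=1e-5, atol=1e-6):
    spmd, pp = torch.load(spmd_path), torch.load(pp_path)
    for section in ("params", "grads"):
        for name, ref in spmd[section].items():
            other = pp[section].get(name)
            if other is None:
                print(f"{section}/{name}: missing in PP dump")
            elif not torch.allclose(ref, other, rtol=rtol, atol=atol):
                print(f"{section}/{name}: max abs diff {(ref - other).abs().max().item():.3e}")
```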
+        torch.use_deterministic_algorithms(True)
+        torch.manual_seed(args.rng_seed)
+
+    run_test(fake_evaluate=args.fake_evaluate, debug_numerics=args.rng_seed is not None)
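Beyond the two calls added above, fully deterministic GPU runs usually also need the cuBLAS workspace variable set and cuDNN autotuning disabled. A minimal sketch of that broader recipe (context only, not something this PR adds):

```python
import os
import torch

def enable_run_to_run_determinism(seed: int) -> None:
    # Some cuBLAS ops require this workspace setting once deterministic
    # algorithms are enforced.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.manual_seed(seed)                   # seeds CPU and all CUDA devices
    torch.use_deterministic_algorithms(True)  # error on nondeterministic kernels
    torch.backends.cudnn.benchmark = False    # avoid per-run cuDNN autotuning
```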
Reviewer: Maybe we should add this to all the graph module calls.
Author: You mean the backward? I didn't add it since I couldn't test it on the base commit.
Reviewer: Yeah, once I land #237 we can add it for full_bw, bw_dI, bw_dW, unshard and reduce_grad.
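If the other stage functions gain a `numerics_logs` keyword like `stage_forward` (that is the assumption here; on this PR's base they do not have one yet), the extra registrations would presumably reuse the same `functools.partial` pattern and slot in next to the existing ones in the example script:

```python
# Sketch only: relies on schedule, numerics_logs and the stage_* functions
# defined in the example above, and assumes each accepts a numerics_logs keyword.
for action, stage_fn in [
    (FULL_BACKWARD, stage_full_backward),
    (REDUCE_GRAD, stage_reduce_grad),
    (RESHARD, stage_reshard),
]:
    schedule.register_custom_function(
        action, functools.partial(stage_fn, numerics_logs=numerics_logs)
    )
```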