Add option for run-to-run deterministic and add optional numerics logging #235
Conversation
Diff context:

```python
    numerics_logs += debug_interpreter.get_logs()
else:
    fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
```
Reviewer: Maybe we should add this to all the graph module calls.
Author: You mean the backward? I didn't add it since I couldn't test it on the base commit.
Reviewer: Yeah, once I land #237 we can add it for full_bw, bw_dI, bw_dW, unshard, and reduce_grad.
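
For reference, a minimal sketch of what a numerics-logging interpreter built on torch.fx.Interpreter could look like; the actual DebugInterpreter in this PR may record different statistics, and the log format here is an assumption:

```python
import torch
import torch.fx

class DebugInterpreter(torch.fx.Interpreter):
    """Sketch: execute a GraphModule node by node, logging per-node numerics."""

    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__(gm)
        self._logs: list[str] = []

    def run_node(self, n: torch.fx.Node):
        out = super().run_node(n)
        # Record a cheap summary statistic for floating-point tensor outputs;
        # full tensor dumps or hashes would also work, at higher cost.
        if isinstance(out, torch.Tensor) and out.is_floating_point():
            self._logs.append(f"{n.name}: norm={out.norm().item():.6e}")
        return out

    def get_logs(self) -> list[str]:
        return self._logs
```

Usage would mirror the diff above: run the graph module through the interpreter, then append get_logs() to numerics_logs.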
Diff context:

```python
args = parser.parse_args()
run_test(fake_evaluate=args.fake_evaluate)
if args.rng_seed is not None:
```
Reviewer: Let's say we have 8 ranks in total; they will all initialize their modules. Since each rank initializes a different part of the model, it is hard to compare with a single-rank implementation for numerics debugging. We should have a solution similar to what @wconstab used in torchtitan: creating a seed checkpoint and using that for PP runs.
Author: By seed checkpoint, do you mean saving and loading random weights generated from an RNG seed? I was thinking of just resetting the seed for weight init.
Reviewer: So if PP has 8 stages, you would do init_weights for each one of them using the same seed? My concern is how you would compare the pp_runtime with SPMD-only runs for numerics.
Author: If we cut our stages at nn.Module boundaries and init weights in the same order, we could reset the seeds at the same cuts during the SPMD init_weights.
For supporting arbitrary stage splits, I would need to know more about how we would implement their init_weights and checkpointing, so I've put that aside for now.
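
A minimal sketch of that seed-reset scheme, assuming stages expose a per-stage init_weights hook (the stage API here is illustrative, not this repo's actual one):

```python
import torch

def deterministic_init(stages, rng_seed: int) -> None:
    """Reset the RNG at every stage cut so PP ranks and the SPMD baseline
    draw identical random streams for weight init."""
    for stage in stages:
        torch.manual_seed(rng_seed)  # same reset at each cut point
        stage.init_weights()         # hypothetical per-stage init hook
```

The SPMD baseline would call torch.manual_seed(rng_seed) at the same nn.Module boundaries during its own init_weights, so both runs produce bitwise-identical initial weights.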
Reviewer: Could you also add an example that saves the params after init and the grads after accumulating grads by running microbatches in SPMD? Analogously, PP would also save the params after init and the grads after running the step, and finally a script would compare both.
Author: Yup, I'm changing example_ds3_local_map.py to use real tensors so it can serve as the SPMD microbatch + accumulate-grad baseline. And I have a script to diff the outputs of DebugInterpreter that I was thinking of landing separately from this PR.
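
As a hedged sketch of the save/compare harness discussed above (file layout, helper names, and tolerances are assumptions; the actual diff script is landing separately):

```python
import torch

def save_snapshot(model: torch.nn.Module, path: str) -> None:
    # Call once after init (grads empty) and once after the accumulation step.
    torch.save(
        {
            "params": {k: p.detach().cpu() for k, p in model.named_parameters()},
            "grads": {k: p.grad.detach().cpu() for k, p in model.named_parameters()
                      if p.grad is not None},
        },
        path,
    )

def compare_snapshots(spmd_path: str, pp_path: str,
                      rtol: float = 0.0, atol: float = 0.0) -> None:
    a, b = torch.load(spmd_path), torch.load(pp_path)
    for group in ("params", "grads"):
        for k in a[group]:
            if not torch.allclose(a[group][k], b[group][k], rtol=rtol, atol=atol):
                max_diff = (a[group][k] - b[group][k]).abs().max().item()
                print(f"{group}/{k}: max abs diff {max_diff:.3e}")
```

For PP, each rank would save only its own stage's parameters, so the comparison would need to stitch per-rank snapshots together (or run per stage) before diffing against the SPMD run.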
sanketpurandare left a comment:
Just land after #237 and add the DebugInterpreter to other graph_module calls if required.
Stacked PRs:
Add option for run-to-run deterministic and add optional numerics logging
Example log (running the same command twice to check run-to-run determinism):

```
tlp torchrun --standalone --nproc-per-node 8 examples/example_ds3_pp.py --rng-seed 1234 | pastry
tlp torchrun --standalone --nproc-per-node 8 examples/example_ds3_pp.py --rng-seed 1234 | pastry
```

Diff of the two runs: https://www.internalfb.com/intern/diffing/?paste_number=2027323783