-
Notifications
You must be signed in to change notification settings - Fork 8
use passed in model's init_weights fn instead of random weight init #20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
fmassa
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks!
examples/example_autoparallel.py
Outdated
|
|
||
| def init_weights(self): | ||
| for lin in [self.wq, self.wk, self.wv, self.wo, self.w1, self.w2]: | ||
| breakpoint() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove?
autoparallel/api.py
Outdated
| sharded_buffers = try_convert_fake_to_real(sharded_buffers_no_fqns) | ||
| # Right now we require a convention that the user model provides an init_weights method, | ||
| # although we could snoop for other methods too. | ||
| if hasattr(self.model, "init_weights") is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
question-
Given that we're specializing on the signature of orig_model.init_weights' being the init fn anyway, do you think it would be worth creating an init_weights` fn on self.parallel_model so that the initialization can be deferred until the user calls it?
This would solve one specific nit for my titan integration, namely that in titan the code goes like
model = ...
model = parallelize_fn(model) # auto_p in here
model.cuda()
model.init_weights() # <-- even if we initialize automatically inside autop, i'd still have to special case titan's train.py to not call it again here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh yeah. If we are ok in the medium term with requiring a user-provided init_weights, and the usual flow is that users will call init_weights themselves anyway, then seems reasonable to let the user call this
I was about to send the above but I'm not actually sure how we would do what you described, unless we ensure that the returned parallel_module has the exact same nn.module hierarchy as the original user model. Since the user's init_weights function will probably have some recursive logic that walks the hierarchy.
We could potentially try to do that, but I'm worried it'll end up being a lot more complicated than the much simpler "keep graph flattened but make sure the FQNs are accurate" plan we have today. wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what I meant was basically to copy your code snippet into a lambda and stick that on parallel_mod as an init_weights method. Not to use the user's original init_weights method. Does this make sense? maybe i can try it and see if it works.
|
how did you validate this on torchtitan?
for debugmodel i saw a huge loss value when using this version of init (and stays high). without your changes, the same command gives me a lower initial loss (and decreasing over time) |
|
@wconstab I just updated the PR to make I tried running my changes with a patched version of your torchtitan+autop integration, and I'm seeing "good" loss. Do you think there's something I'm missing in my repro? I tried tweaking your torchtitan branch to do this: |
| # assign an init_weights method onto the output mod. | ||
| # all it does is sneakily run the original user mod's init_weights method, | ||
| # but with our new DTensor sharded params attached to the user module. | ||
| self.parallel_model.init_weights = init_weights |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm ok with this for now, but then we would need to clearly specify the contract of what methods are propagated from the original model vs not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed. if we want to be extreme, we could make this a strict requirement and error out if the user's mod doesn't have a init_weights method. Or maybe once we get further along adoption-wise we should just write some docs that clearly spell out the restrictions / user requirements, this being one of them?
|
I'm trying the latest PR again now
note that in your paste above, the loss value is exactly the same on each step. I don't think this is OK. However I did previously see totally exploding losses, which I no longer see. Instead, it almost seems like the parameters are "frozen at some random value forever". Here is what I get on debugmodel without your change: and here's what I get with this PR at 03b4c24 and this change to torchtitan train.py |
|
ok update - after a small tweak to this PR, I realized the other problem is that torchtitan uses This ends up playing poorly with the param init, because: (1) With my updated PR, I now get what looks like reasonable loss with this titan change: loss: |
|
this works for me now. |
|
I ended up updating so that: (1) to_empty() from the user works, our init_weights fn now just closes over the state dict so we can grab whatever the user dumps in there and initialize it (2) i also had to tweak autoparallel to return a state dict containing DTensor(meta_tensor) instead of DTensor(FakeTensor), since with fake tensor, the |
|
and I confirmed reasonable loss on titan with the debug model with this tweak: loss |
|
testing locally and will merge this and torchtitan/autoparallel changes once verified. thanks! |
autoparallel/api.py
Outdated
|
|
||
| # Right now we require a convention that the user model provides an init_weights method, | ||
| # although we could snoop for other methods too. | ||
| if hasattr(self.model, "init_weights") is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if hasattr(self.model, "init_weights") is not None:
shouldn't this be just
if hasattr(self.model, "init_weights"):?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pushed a fix.
Relying on meta-pytorch/autoparallel#20, this lets us automatically apply a user's init_weights fn to the autoparallel model. Verified this works with `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4` ``` [rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step: 1 loss: 8.1848 memory: 1.09GiB(1.14%) tps: 77 tflops: 0.01 mfu: 0.00% [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step: 2 loss: 8.1619 memory: 1.15GiB(1.21%) tps: 48,138 tflops: 3.46 mfu: 0.35 % [rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step: 3 loss: 8.1140 memory: 1.15GiB(1.21%) tps: 88,440 tflops: 6.36 mfu: 0.64 % [rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step: 4 loss: 8.0099 memory: 1.15GiB(1.21%) tps: 82,626 tflops: 5.94 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step: 5 loss: 7.8928 memory: 1.15GiB(1.21%) tps: 81,594 tflops: 5.87 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step: 6 loss: 7.7758 memory: 1.15GiB(1.21%) tps: 79,607 tflops: 5.72 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step: 7 loss: 7.6221 memory: 1.15GiB(1.21%) tps: 81,448 tflops: 5.86 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step: 8 loss: 7.5578 memory: 1.15GiB(1.21%) tps: 79,732 tflops: 5.73 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step: 9 loss: 7.3851 memory: 1.15GiB(1.21%) tps: 85,655 tflops: 6.16 mfu: 0.62 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10 loss: 7.3361 memory: 1.15GiB(1.21%) tps: 81,855 tflops: 5.89 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete ```
TODO - try converting model params into fake tensors - figure out init fn - integrate torchtitan configs for DP/TP to control autop Hack an init_fn for llama3 and observe loss decreasing with autoparallel """ [rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step: 1 loss: 8.1880 memory: 4.88GiB(6.16%) tps: 28 [rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step: 2 loss: 8.1610 memory: 4.90GiB(6.20%) tps: 13,785 [rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step: 3 loss: 8.0871 memory: 4.90GiB(6.20%) tps: 14,006 [rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step: 4 loss: 7.9516 memory: 4.90GiB(6.20%) tps: 13,770 [rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step: 5 loss: 7.8552 memory: 4.90GiB(6.20%) tps: 13,959 [rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step: 6 loss: 7.7732 memory: 4.90GiB(6.20%) tps: 13,859 [rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step: 7 loss: 7.6987 memory: 4.90GiB(6.20%) tps: 13,664 [rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step: 8 loss: 7.6779 memory: 4.90GiB(6.20%) tps: 13,985 [rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step: 9 loss: 7.6043 memory: 4.90GiB(6.20%) tps: 13,962 [rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10 loss: 7.5778 memory: 4.90GiB(6.20%) tps: 13,891 """ Adopt new autoparallel API with meta-init model Allows reverting a lot of the hacks in the original integration that were caused by not creating a model obj in the train.py due to passing a model_fn builder to autop. Fixes to align with latest autoparallel Add inductor config knobs for comms optimizations to torchtitan Make inductor always run compile passes basically, this is an annoying workaround for debugging iteratively. 1- you run the model, it compiles, but something weird happens 2- you enable some logging or tlparse, rerun. but inductor decides not to run your pass anymore, its results are cached. since (2) has confused me horribly on more than one occasion, i just disable caching for now Drop hacky llama3_init_fn and use autop init_weights feature Relying on meta-pytorch/autoparallel#20, this lets us automatically apply a user's init_weights fn to the autoparallel model. Verified this works with `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4` ``` [rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step: 1 loss: 8.1848 memory: 1.09GiB(1.14%) tps: 77 tflops: 0.01 mfu: 0.00% [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step: 2 loss: 8.1619 memory: 1.15GiB(1.21%) tps: 48,138 tflops: 3.46 mfu: 0.35 % [rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step: 3 loss: 8.1140 memory: 1.15GiB(1.21%) tps: 88,440 tflops: 6.36 mfu: 0.64 % [rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step: 4 loss: 8.0099 memory: 1.15GiB(1.21%) tps: 82,626 tflops: 5.94 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step: 5 loss: 7.8928 memory: 1.15GiB(1.21%) tps: 81,594 tflops: 5.87 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step: 6 loss: 7.7758 memory: 1.15GiB(1.21%) tps: 79,607 tflops: 5.72 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step: 7 loss: 7.6221 memory: 1.15GiB(1.21%) tps: 81,448 tflops: 5.86 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step: 8 loss: 7.5578 memory: 1.15GiB(1.21%) tps: 79,732 tflops: 5.73 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step: 9 loss: 7.3851 memory: 1.15GiB(1.21%) tps: 85,655 tflops: 6.16 mfu: 0.62 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10 loss: 7.3361 memory: 1.15GiB(1.21%) tps: 81,855 tflops: 5.89 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete ``` fix lint
TODO - try converting model params into fake tensors - figure out init fn - integrate torchtitan configs for DP/TP to control autop Hack an init_fn for llama3 and observe loss decreasing with autoparallel """ [rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step: 1 loss: 8.1880 memory: 4.88GiB(6.16%) tps: 28 [rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step: 2 loss: 8.1610 memory: 4.90GiB(6.20%) tps: 13,785 [rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step: 3 loss: 8.0871 memory: 4.90GiB(6.20%) tps: 14,006 [rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step: 4 loss: 7.9516 memory: 4.90GiB(6.20%) tps: 13,770 [rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step: 5 loss: 7.8552 memory: 4.90GiB(6.20%) tps: 13,959 [rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step: 6 loss: 7.7732 memory: 4.90GiB(6.20%) tps: 13,859 [rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step: 7 loss: 7.6987 memory: 4.90GiB(6.20%) tps: 13,664 [rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step: 8 loss: 7.6779 memory: 4.90GiB(6.20%) tps: 13,985 [rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step: 9 loss: 7.6043 memory: 4.90GiB(6.20%) tps: 13,962 [rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10 loss: 7.5778 memory: 4.90GiB(6.20%) tps: 13,891 """ Adopt new autoparallel API with meta-init model Allows reverting a lot of the hacks in the original integration that were caused by not creating a model obj in the train.py due to passing a model_fn builder to autop. Fixes to align with latest autoparallel Add inductor config knobs for comms optimizations to torchtitan Make inductor always run compile passes basically, this is an annoying workaround for debugging iteratively. 1- you run the model, it compiles, but something weird happens 2- you enable some logging or tlparse, rerun. but inductor decides not to run your pass anymore, its results are cached. since (2) has confused me horribly on more than one occasion, i just disable caching for now Drop hacky llama3_init_fn and use autop init_weights feature Relying on meta-pytorch/autoparallel#20, this lets us automatically apply a user's init_weights fn to the autoparallel model. Verified this works with `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4` ``` [rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step: 1 loss: 8.1848 memory: 1.09GiB(1.14%) tps: 77 tflops: 0.01 mfu: 0.00% [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step: 2 loss: 8.1619 memory: 1.15GiB(1.21%) tps: 48,138 tflops: 3.46 mfu: 0.35 % [rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step: 3 loss: 8.1140 memory: 1.15GiB(1.21%) tps: 88,440 tflops: 6.36 mfu: 0.64 % [rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step: 4 loss: 8.0099 memory: 1.15GiB(1.21%) tps: 82,626 tflops: 5.94 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step: 5 loss: 7.8928 memory: 1.15GiB(1.21%) tps: 81,594 tflops: 5.87 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step: 6 loss: 7.7758 memory: 1.15GiB(1.21%) tps: 79,607 tflops: 5.72 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step: 7 loss: 7.6221 memory: 1.15GiB(1.21%) tps: 81,448 tflops: 5.86 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step: 8 loss: 7.5578 memory: 1.15GiB(1.21%) tps: 79,732 tflops: 5.73 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step: 9 loss: 7.3851 memory: 1.15GiB(1.21%) tps: 85,655 tflops: 6.16 mfu: 0.62 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10 loss: 7.3361 memory: 1.15GiB(1.21%) tps: 81,855 tflops: 5.89 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete ``` fix lint [ghstack-poisoned]
TODO - try converting model params into fake tensors - figure out init fn - integrate torchtitan configs for DP/TP to control autop Hack an init_fn for llama3 and observe loss decreasing with autoparallel """ [rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step: 1 loss: 8.1880 memory: 4.88GiB(6.16%) tps: 28 [rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step: 2 loss: 8.1610 memory: 4.90GiB(6.20%) tps: 13,785 [rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step: 3 loss: 8.0871 memory: 4.90GiB(6.20%) tps: 14,006 [rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step: 4 loss: 7.9516 memory: 4.90GiB(6.20%) tps: 13,770 [rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step: 5 loss: 7.8552 memory: 4.90GiB(6.20%) tps: 13,959 [rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step: 6 loss: 7.7732 memory: 4.90GiB(6.20%) tps: 13,859 [rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step: 7 loss: 7.6987 memory: 4.90GiB(6.20%) tps: 13,664 [rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step: 8 loss: 7.6779 memory: 4.90GiB(6.20%) tps: 13,985 [rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step: 9 loss: 7.6043 memory: 4.90GiB(6.20%) tps: 13,962 [rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10 loss: 7.5778 memory: 4.90GiB(6.20%) tps: 13,891 """ Adopt new autoparallel API with meta-init model Allows reverting a lot of the hacks in the original integration that were caused by not creating a model obj in the train.py due to passing a model_fn builder to autop. Fixes to align with latest autoparallel Add inductor config knobs for comms optimizations to torchtitan Make inductor always run compile passes basically, this is an annoying workaround for debugging iteratively. 1- you run the model, it compiles, but something weird happens 2- you enable some logging or tlparse, rerun. but inductor decides not to run your pass anymore, its results are cached. since (2) has confused me horribly on more than one occasion, i just disable caching for now Drop hacky llama3_init_fn and use autop init_weights feature Relying on meta-pytorch/autoparallel#20, this lets us automatically apply a user's init_weights fn to the autoparallel model. Verified this works with `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4` ``` [rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step: 1 loss: 8.1848 memory: 1.09GiB(1.14%) tps: 77 tflops: 0.01 mfu: 0.00% [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step: 2 loss: 8.1619 memory: 1.15GiB(1.21%) tps: 48,138 tflops: 3.46 mfu: 0.35 % [rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step: 3 loss: 8.1140 memory: 1.15GiB(1.21%) tps: 88,440 tflops: 6.36 mfu: 0.64 % [rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step: 4 loss: 8.0099 memory: 1.15GiB(1.21%) tps: 82,626 tflops: 5.94 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step: 5 loss: 7.8928 memory: 1.15GiB(1.21%) tps: 81,594 tflops: 5.87 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step: 6 loss: 7.7758 memory: 1.15GiB(1.21%) tps: 79,607 tflops: 5.72 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step: 7 loss: 7.6221 memory: 1.15GiB(1.21%) tps: 81,448 tflops: 5.86 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step: 8 loss: 7.5578 memory: 1.15GiB(1.21%) tps: 79,732 tflops: 5.73 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step: 9 loss: 7.3851 memory: 1.15GiB(1.21%) tps: 85,655 tflops: 6.16 mfu: 0.62 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10 loss: 7.3361 memory: 1.15GiB(1.21%) tps: 81,855 tflops: 5.89 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete ``` fix lint [ghstack-poisoned]
If the user's model has a
model.init_weights()function (as in torchtitan), this PR allows us to use that function and run it on our sharded DTensor parameters inside ofAutoParallel.apply_placement. State of the PR:I confirmed that we are actually dispatching the init ops to our DTensor sharded params (
aten.normal_), and manually confirmed that Will's torchtitan integration doesn't break. We don't have any tests that use a real NCCL ProcessGroup today, I could try to write a test that asserts we get the same weight init as single-GPU if all params are replicated if people are interested.we still need to solve the state dict / FQN problem. Right now I just create two dicts that hold the same DTensor params/buffers, one with real FQN keys (which I need to reparametrize the original user model and do param init), and one with the "fake" FQN keys that already exist today, that we end up putting on the output
fx.GraphModule, since it is flattened and does not have proper FQNs.