
Conversation

@bdhirsh
Contributor

@bdhirsh bdhirsh commented Jun 26, 2025

If the user's model has a model.init_weights() function (as in torchtitan), this PR allows us to use that function and run it on our sharded DTensor parameters inside of AutoParallel.apply_placement. State of the PR:

  • I confirmed that we are actually dispatching the init ops to our DTensor sharded params (aten.normal_), and manually confirmed that Will's torchtitan integration doesn't break. We don't have any tests that use a real NCCL ProcessGroup today; if people are interested, I could try to write a test asserting that we get the same weight init as single-GPU when all params are replicated.

  • we still need to solve the state dict / FQN problem. Right now I just create two dicts that hold the same DTensor params/buffers: one with real FQN keys (which I need in order to reparametrize the original user model and run param init), and one with the "fake" FQN keys that already exist today, which we end up putting on the output fx.GraphModule since it is flattened and does not have proper FQNs (see the sketch below).
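Purely as illustration of the two-dict setup above (the FQN strings and the flattened key format here are hypothetical, not the actual autoparallel naming):

```python
import torch

# stand-in for one sharded DTensor parameter produced by apply_placement
dtensor_param = torch.empty(16, 16)

# keyed by the real module FQN: used to reparametrize the user model and run init_weights
params_by_real_fqn = {"tok_embeddings.weight": dtensor_param}

# keyed by the flattened "fake" FQN that lives on the output fx.GraphModule
params_by_flat_fqn = {"p_tok_embeddings_weight": dtensor_param}

# both dicts hold the very same parameter object
assert params_by_real_fqn["tok_embeddings.weight"] is params_by_flat_fqn["p_tok_embeddings_weight"]
```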

@bdhirsh bdhirsh requested a review from fmassa June 26, 2025 16:02
@bdhirsh bdhirsh requested review from ezyang and wconstab June 26, 2025 16:02
Contributor

@fmassa fmassa left a comment


Looks great, thanks!


def init_weights(self):
    for lin in [self.wq, self.wk, self.wv, self.wo, self.w1, self.w2]:
        breakpoint()
Contributor


nit: remove?

sharded_buffers = try_convert_fake_to_real(sharded_buffers_no_fqns)
# Right now we require a convention that the user model provides an init_weights method,
# although we could snoop for other methods too.
if hasattr(self.model, "init_weights") is not None:
Contributor


question -

Given that we're specializing on the signature of `orig_model.init_weights` being the init fn anyway, do you think it would be worth creating an `init_weights` fn on `self.parallel_model` so that the initialization can be deferred until the user calls it?

This would solve one specific nit for my titan integration, namely that in titan the code goes like

model = ...
model = parallelize_fn(model) # auto_p in here
model.cuda()
model.init_weights() # <-- even if we initialize automatically inside autop, I'd still have to special-case titan's train.py to not call it again here

Contributor Author


oh yeah. If we are ok in the medium term with requiring a user-provided init_weights, and the usual flow is that users will call init_weights themselves anyway, then it seems reasonable to let the user call this.

I was about to send the above, but I'm not actually sure how we would do what you described unless we ensure that the returned parallel_module has the exact same nn.Module hierarchy as the original user model, since the user's init_weights function will probably have some recursive logic that walks the hierarchy.

We could potentially try to do that, but I'm worried it'll end up being a lot more complicated than the much simpler "keep the graph flattened but make sure the FQNs are accurate" plan we have today. wdyt?

Contributor


what I meant was basically to copy your code snippet into a lambda and stick that on parallel_mod as an init_weights method, not to use the user's original init_weights method. Does this make sense? Maybe I can try it and see if it works.
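For concreteness, a minimal runnable sketch of that suggestion (a plain nn.Linear stands in for the autoparallel output module; all names here are placeholders): build the init logic as a closure and attach it to the parallelized module, so initialization is deferred until the user calls it, matching the titan flow above.

```python
import torch
import torch.nn as nn

parallel_mod = nn.Linear(8, 8)  # stand-in for the module returned by apply_placement

def _init_weights():
    # closes over parallel_mod and runs the same init logic autop would have run eagerly
    with torch.no_grad():
        for p in parallel_mod.parameters():
            nn.init.normal_(p, std=0.02)

# deferred: nothing is initialized until the user calls parallel_mod.init_weights()
parallel_mod.init_weights = _init_weights
```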

@wconstab
Contributor

how did you validate this on torchtitan?

CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4

For debugmodel, I saw a huge loss value when using this version of init (and it stays high).

[rank0]:[titan] 2025-06-26 17:12:11,974 - root - INFO - step:  1  loss: 51.1987  memory:  1.08GiB(1.14%)  tps: 80  tflops: 0.01  mfu: 0.00%

Without your changes, the same command gives me a lower initial loss (decreasing over time):

[rank0]:[titan] 2025-06-26 17:14:16,303 - root - INFO - step:  1  loss:  8.2055  memory:  1.08GiB(1.14%)  tps: 76  tflops: 0.01  mfu: 0.00%

@bdhirsh
Contributor Author

bdhirsh commented Jul 1, 2025

@wconstab I just updated the PR to make init_weights a method on the output nn.Module returned by apply_placement.

I tried running my changes with a patched version of your torchtitan+autop integration, and I'm seeing "good" loss. Do you think there's something I'm missing in my repro?

[rank0]:[titan] 2025-07-01 15:38:30,649 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-01 15:38:37,287 - root - INFO - step:  1  loss:  7.7213  memory:  1.09GiB(1.14%)  tps: 66  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-07-01 15:38:37,288 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-01 15:38:37,508 - root - INFO - step:  2  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 18,567  tflops: 1.34  mfu: 0.13%
[rank0]:[titan] 2025-07-01 15:38:37,719 - root - INFO - step:  3  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 19,497  tflops: 1.40  mfu: 0.14%
[rank0]:[titan] 2025-07-01 15:38:37,931 - root - INFO - step:  4  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 19,344  tflops: 1.39  mfu: 0.14%
[rank0]:[titan] 2025-07-01 15:38:38,144 - root - INFO - step:  5  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 19,216  tflops: 1.38  mfu: 0.14%
[rank0]:[titan] 2025-07-01 15:38:38,357 - root - INFO - step:  6  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 19,343  tflops: 1.39  mfu: 0.14%
[rank0]:[titan] 2025-07-01 15:38:38,568 - root - INFO - step:  7  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 19,414  tflops: 1.40  mfu: 0.14%
[rank0]:[titan] 2025-07-01 15:38:38,781 - root - INFO - step:  8  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 19,286  tflops: 1.39  mfu: 0.14%
[rank0]:[titan] 2025-07-01 15:38:38,991 - root - INFO - step:  9  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 19,474  tflops: 1.40  mfu: 0.14%
[rank0]:[titan] 2025-07-01 15:38:39,201 - root - INFO - step: 10  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 19,514  tflops: 1.40  mfu: 0.14%

I tried tweaking your torchtitan branch to do this:

diff --git a/torchtitan/train.py b/torchtitan/train.py
index 50c5424..20a1ccd 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -339,9 +339,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):

             model.to_empty(device=init_device)
             with torch.no_grad():
-                # TODO(whc) make model.init_weights work with autoparallel
-                llama3_autoparallel_init_fn(model)
-                # model.init_weights(buffer_device=buffer_device)
+                model.init_weights(buffer_device=buffer_device)
             model.train()

             self.model_parts = [model]

@bdhirsh bdhirsh force-pushed the use_init_weights branch from bc4b8a0 to 03b4c24 on July 1, 2025 22:48
Comment on lines +432 to +435
# assign an init_weights method onto the output mod.
# all it does is sneakily run the original user mod's init_weights method,
# but with our new DTensor sharded params attached to the user module.
self.parallel_model.init_weights = init_weights
Contributor


I'm ok with this for now, but then we would need to clearly specify the contract of what methods are propagated from the original model vs not.

Contributor Author


Agreed. If we want to be extreme, we could make this a strict requirement and error out if the user's mod doesn't have an init_weights method. Or maybe, once we get further along adoption-wise, we should just write some docs that clearly spell out the restrictions / user requirements, this being one of them?
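For reference, the stricter variant floated here would amount to something like the following check (a sketch of the idea, not what the PR currently does):

```python
import torch.nn as nn

def check_init_weights(model: nn.Module) -> None:
    # fail fast instead of silently skipping weight init
    if not hasattr(model, "init_weights"):
        raise RuntimeError(
            "AutoParallel requires the user model to define an init_weights() method"
        )
```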

@wconstab
Contributor

wconstab commented Jul 2, 2025

I'm trying the latest PR again now

I tried running my changes with a patched version of your torchtitan+autop integration, and I'm seeing "good" loss. Do you think there's something I'm missing in my repro?

Note that in your paste above, the loss value is exactly the same on each step, which I don't think is OK. However, I did previously see totally exploding losses, which I no longer see. Instead, it almost seems like the parameters are "frozen at some random value forever".

Here is what I get on debugmodel without your change:
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4

[rank0]:[titan] 2025-07-02 11:19:23,802 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-02 11:19:28,472 - root - INFO - step:  1  loss:  8.2162  memory:  1.09GiB(1.14%)  tps: 79  tflops: 0.01  mfu: 0.00%
[rank0]:[titan] 2025-07-02 11:19:28,473 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-02 11:19:28,524 - root - INFO - step:  2  loss:  8.1852  memory:  1.15GiB(1.21%)  tps: 79,303  tflops: 5.70  mfu: 0.58%
[rank0]:[titan] 2025-07-02 11:19:28,573 - root - INFO - step:  3  loss:  8.1031  memory:  1.15GiB(1.21%)  tps: 85,385  tflops: 6.14  mfu: 0.62%
[rank0]:[titan] 2025-07-02 11:19:28,623 - root - INFO - step:  4  loss:  8.0311  memory:  1.15GiB(1.21%)  tps: 81,445  tflops: 5.86  mfu: 0.59%
[rank0]:[titan] 2025-07-02 11:19:28,675 - root - INFO - step:  5  loss:  7.9235  memory:  1.15GiB(1.21%)  tps: 79,995  tflops: 5.75  mfu: 0.58%
[rank0]:[titan] 2025-07-02 11:19:28,725 - root - INFO - step:  6  loss:  7.8281  memory:  1.15GiB(1.21%)  tps: 82,036  tflops: 5.90  mfu: 0.60%
[rank0]:[titan] 2025-07-02 11:19:28,774 - root - INFO - step:  7  loss:  7.6741  memory:  1.15GiB(1.21%)  tps: 84,234  tflops: 6.06  mfu: 0.61%
[rank0]:[titan] 2025-07-02 11:19:28,823 - root - INFO - step:  8  loss:  7.6253  memory:  1.15GiB(1.21%)  tps: 83,043  tflops: 5.97  mfu: 0.60%
[rank0]:[titan] 2025-07-02 11:19:28,872 - root - INFO - step:  9  loss:  7.4764  memory:  1.15GiB(1.21%)  tps: 84,999  tflops: 6.11  mfu: 0.62%
[rank0]:[titan] 2025-07-02 11:19:28,920 - root - INFO - step: 10  loss:  7.4103  memory:  1.15GiB(1.21%)  tps: 84,756  tflops: 6.09  mfu: 0.62%

and here's what I get with this PR at 03b4c24 and this change to torchtitan train.py

            with torch.no_grad():
                # TODO(whc) make model.init_weights work with autoparallel
                # llama3_autoparallel_init_fn(model)
                model.init_weights(buffer_device=buffer_device)
            model.train()
[rank0]:[titan] 2025-07-02 11:21:52,421 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-02 11:21:59,324 - root - INFO - step:  1  loss:  7.7213  memory:  1.09GiB(1.14%)  tps: 78  tflops: 0.01  mfu: 0.00%
[rank0]:[titan] 2025-07-02 11:21:59,325 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-02 11:21:59,375 - root - INFO - step:  2  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 81,350  tflops: 5.85  mfu: 0.59%
[rank0]:[titan] 2025-07-02 11:21:59,421 - root - INFO - step:  3  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 89,207  tflops: 6.41  mfu: 0.65%
[rank0]:[titan] 2025-07-02 11:21:59,469 - root - INFO - step:  4  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 85,576  tflops: 6.15  mfu: 0.62%
[rank0]:[titan] 2025-07-02 11:21:59,520 - root - INFO - step:  5  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 82,083  tflops: 5.90  mfu: 0.60%
[rank0]:[titan] 2025-07-02 11:21:59,571 - root - INFO - step:  6  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 80,706  tflops: 5.80  mfu: 0.59%
[rank0]:[titan] 2025-07-02 11:21:59,624 - root - INFO - step:  7  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 77,181  tflops: 5.55  mfu: 0.56%
[rank0]:[titan] 2025-07-02 11:21:59,675 - root - INFO - step:  8  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 80,793  tflops: 5.81  mfu: 0.59%
[rank0]:[titan] 2025-07-02 11:21:59,724 - root - INFO - step:  9  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 84,100  tflops: 6.05  mfu: 0.61%
[rank0]:[titan] 2025-07-02 11:21:59,775 - root - INFO - step: 10  loss:  7.7213  memory:  1.15GiB(1.21%)  tps: 81,153  tflops: 5.84  mfu: 0.59%

@bdhirsh
Contributor Author

bdhirsh commented Jul 2, 2025

ok, update - after a small tweak to this PR, I realized the other problem is that torchtitan uses model.to_empty() to allocate its (sharded) params before initializing them.

This ends up playing poorly with the param init, because:

(1) optimize_placements() has already allocated real DTensor params for you on the proper device, so this is redundant
(2) the init_fn we are putting on the new module closes over those allocated DTensor params, but torchtitan ends up throwing them away and inserting brand-new (empty) params into the state dict. So when you run init_weights(), we end up running initialization on the DTensor params that auto-parallel constructed, not on the ones that titan constructs with its to_empty() call (see the sketch below).
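A self-contained sketch of that failure mode, with a plain nn.Linear standing in for the real model (names are illustrative): an init fn that closes over the current parameter objects stops affecting the module once to_empty() swaps in brand-new parameters.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)

# the init fn closes over the parameter objects that exist right now
captured = list(model.parameters())

def init_weights():
    with torch.no_grad():
        for p in captured:
            p.normal_()

# to_empty() replaces every parameter with a freshly allocated empty tensor...
model.to_empty(device="cpu")

init_weights()  # ...so this initializes the old, orphaned tensors

# the module's actual parameters are not the ones that were just initialized
assert model.weight is not captured[0]
```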

With my updated PR, I now get what looks like reasonable loss with this titan change:

diff --git a/torchtitan/train.py b/torchtitan/train.py
index 50c5424..2f17341 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -337,11 +337,8 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
                 model, world_mesh, parallel_dims, job_config
             )

-            model.to_empty(device=init_device)
             with torch.no_grad():
-                # TODO(whc) make model.init_weights work with autoparallel
-                llama3_autoparallel_init_fn(model)
-                # model.init_weights(buffer_device=buffer_device)
+                model.init_weights(buffer_device=buffer_device)
             model.train()

             self.model_parts = [model]

loss:

[rank0]:[titan] 2025-07-02 12:34:57,393 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-02 12:34:57,609 - root - INFO - step:  2  loss:  8.1936  memory:  1.15GiB(1.21%)  tps: 18,988  tflops: 1.37  mfu: 0.14%
[rank0]:[titan] 2025-07-02 12:34:57,819 - root - INFO - step:  3  loss:  8.1413  memory:  1.15GiB(1.21%)  tps: 19,487  tflops: 1.40  mfu: 0.14%
[rank0]:[titan] 2025-07-02 12:34:58,030 - root - INFO - step:  4  loss:  8.0574  memory:  1.15GiB(1.21%)  tps: 19,495  tflops: 1.40  mfu: 0.14%
[rank0]:[titan] 2025-07-02 12:34:58,242 - root - INFO - step:  5  loss:  7.9381  memory:  1.15GiB(1.21%)  tps: 19,306  tflops: 1.39  mfu: 0.14%
[rank0]:[titan] 2025-07-02 12:34:58,454 - root - INFO - step:  6  loss:  7.8293  memory:  1.15GiB(1.21%)  tps: 19,351  tflops: 1.39  mfu: 0.14%
[rank0]:[titan] 2025-07-02 12:34:58,667 - root - INFO - step:  7  loss:  7.7050  memory:  1.15GiB(1.21%)  tps: 19,258  tflops: 1.38  mfu: 0.14%
[rank0]:[titan] 2025-07-02 12:34:58,880 - root - INFO - step:  8  loss:  7.6321  memory:  1.15GiB(1.21%)  tps: 19,246  tflops: 1.38  mfu: 0.14%
[rank0]:[titan] 2025-07-02 12:34:59,093 - root - INFO - step:  9  loss:  7.4880  memory:  1.15GiB(1.21%)  tps: 19,320  tflops: 1.39  mfu: 0.14%
[rank0]:[titan] 2025-07-02 12:34:59,306 - root - INFO - step: 10  loss:  7.4143  memory:  1.15GiB(1.21%)  tps: 19,262  tflops: 1.39  mfu: 0.14%

@wconstab
Contributor

wconstab commented Jul 2, 2025

This works for me now.
Thanks!
Though, I am debating whether we want to make to_empty behave more like the non-autoparallel case. I think we do, but we can address that in another PR.

@bdhirsh
Contributor Author

bdhirsh commented Jul 2, 2025

I ended up updating so that:

(1) to_empty() from the user now works: our init_weights fn just closes over the state dict, so we can grab whatever the user dumps in there and initialize it (see the sketch below)

(2) I also had to tweak autoparallel to return a state dict containing DTensor(meta_tensor) instead of DTensor(FakeTensor), since with fake tensors, nn.Module.to_empty() doesn't seem to do what we want (it just allocates more FakeTensors)
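And a companion sketch of the fix in (1), again with made-up names: look the parameters up at call time, rather than capturing them when the init fn is created, so whatever to_empty() installs is what gets initialized.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)

def init_weights():
    with torch.no_grad():
        # re-fetch the params at call time; picks up whatever to_empty() swapped in
        for _, p in model.named_parameters():
            p.normal_()

model.to_empty(device="cpu")
init_weights()  # now initializes the post-to_empty parameters
```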

@bdhirsh
Contributor Author

bdhirsh commented Jul 2, 2025

and I confirmed reasonable loss on titan with the debug model using this tweak:

diff --git a/torchtitan/train.py b/torchtitan/train.py
index 50c5424..20a1ccd 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -339,9 +339,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):

             model.to_empty(device=init_device)
             with torch.no_grad():
-                # TODO(whc) make model.init_weights work with autoparallel
-                llama3_autoparallel_init_fn(model)
-                # model.init_weights(buffer_device=buffer_device)
+                model.init_weights(buffer_device=buffer_device)
             model.train()

             self.model_parts = [model]

loss

[rank0]:[titan] 2025-07-02 15:36:56,880 - root - INFO - step:  1  loss:  8.2232  memory:  1.09GiB(1.14%)  tps: 66  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-07-02 15:36:56,880 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-02 15:36:57,097 - root - INFO - step:  2  loss:  8.1821  memory:  1.15GiB(1.21%)  tps: 18,896  tflops: 1.36  mfu: 0.14%
[rank0]:[titan] 2025-07-02 15:36:57,309 - root - INFO - step:  3  loss:  8.1249  memory:  1.15GiB(1.21%)  tps: 19,315  tflops: 1.39  mfu: 0.14%
[rank0]:[titan] 2025-07-02 15:36:57,521 - root - INFO - step:  4  loss:  8.0224  memory:  1.15GiB(1.21%)  tps: 19,376  tflops: 1.39  mfu: 0.14%
[rank0]:[titan] 2025-07-02 15:36:57,732 - root - INFO - step:  5  loss:  7.9030  memory:  1.15GiB(1.21%)  tps: 19,456  tflops: 1.40  mfu: 0.14%
[rank0]:[titan] 2025-07-02 15:36:57,943 - root - INFO - step:  6  loss:  7.7718  memory:  1.15GiB(1.21%)  tps: 19,463  tflops: 1.40  mfu: 0.14%
[rank0]:[titan] 2025-07-02 15:36:58,155 - root - INFO - step:  7  loss:  7.6099  memory:  1.15GiB(1.21%)  tps: 19,361  tflops: 1.39  mfu: 0.14%
[rank0]:[titan] 2025-07-02 15:36:58,368 - root - INFO - step:  8  loss:  7.5154  memory:  1.15GiB(1.21%)  tps: 19,235  tflops: 1.38  mfu: 0.14%
[rank0]:[titan] 2025-07-02 15:36:58,578 - root - INFO - step:  9  loss:  7.3545  memory:  1.15GiB(1.21%)  tps: 19,581  tflops: 1.41  mfu: 0.14%
[rank0]:[titan] 2025-07-02 15:36:58,788 - root - INFO - step: 10  loss:  7.2924  memory:  1.15GiB(1.21%)  tps: 19,540  tflops: 1.40  mfu: 0.14%

@wconstab
Contributor

wconstab commented Jul 2, 2025

Testing locally; will merge this and the torchtitan/autoparallel changes once verified. Thanks!


# Right now we require a convention that the user model provides an init_weights method,
# although we could snoop for other methods too.
if hasattr(self.model, "init_weights") is not None:
Contributor


shouldn't this be just

if hasattr(self.model, "init_weights"):

? hasattr already returns a bool, so the is not None check is always True.

Contributor


pushed a fix.

@wconstab wconstab merged commit 1ca0f1d into main Jul 2, 2025
4 checks passed
wconstab added a commit to pytorch/torchtitan that referenced this pull request Jul 2, 2025
Relying on meta-pytorch/autoparallel#20, this
lets us automatically apply a user's init_weights fn to the autoparallel
model.

Verified this works with

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4`

```
[rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step:  1  loss:  8.1848  memory:  1.09GiB(1.14%)  tps: 77  tflops: 0.01  mfu: 0.00%
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step:  2  loss:  8.1619  memory:  1.15GiB(1.21%)  tps: 48,138  tflops: 3.46  mfu: 0.35
%
[rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step:  3  loss:  8.1140  memory:  1.15GiB(1.21%)  tps: 88,440  tflops: 6.36  mfu: 0.64
%
[rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step:  4  loss:  8.0099  memory:  1.15GiB(1.21%)  tps: 82,626  tflops: 5.94  mfu: 0.60
%
[rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step:  5  loss:  7.8928  memory:  1.15GiB(1.21%)  tps: 81,594  tflops: 5.87  mfu: 0.59
%
[rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step:  6  loss:  7.7758  memory:  1.15GiB(1.21%)  tps: 79,607  tflops: 5.72  mfu: 0.58
%
[rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step:  7  loss:  7.6221  memory:  1.15GiB(1.21%)  tps: 81,448  tflops: 5.86  mfu: 0.59
%
[rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step:  8  loss:  7.5578  memory:  1.15GiB(1.21%)  tps: 79,732  tflops: 5.73  mfu: 0.58
%
[rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step:  9  loss:  7.3851  memory:  1.15GiB(1.21%)  tps: 85,655  tflops: 6.16  mfu: 0.62
%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10  loss:  7.3361  memory:  1.15GiB(1.21%)  tps: 81,855  tflops: 5.89  mfu: 0.60
%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete
```
@fmassa fmassa deleted the use_init_weights branch July 3, 2025 08:29
wconstab added a commit to pytorch/torchtitan that referenced this pull request Jul 11, 2025
TODO
- try converting model params into fake tensors
- figure out init fn
- integrate torchtitan configs for DP/TP to control autop

Hack an init_fn for llama3 and observe loss decreasing with autoparallel

"""
[rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step:  1  loss:  8.1880  memory:  4.88GiB(6.16%)  tps: 28
[rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step:  2  loss:  8.1610  memory:  4.90GiB(6.20%)  tps: 13,785
[rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step:  3  loss:  8.0871  memory:  4.90GiB(6.20%)  tps: 14,006
[rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step:  4  loss:  7.9516  memory:  4.90GiB(6.20%)  tps: 13,770
[rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step:  5  loss:  7.8552  memory:  4.90GiB(6.20%)  tps: 13,959
[rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step:  6  loss:  7.7732  memory:  4.90GiB(6.20%)  tps: 13,859
[rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step:  7  loss:  7.6987  memory:  4.90GiB(6.20%)  tps: 13,664
[rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step:  8  loss:  7.6779  memory:  4.90GiB(6.20%)  tps: 13,985
[rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step:  9  loss:  7.6043  memory:  4.90GiB(6.20%)  tps: 13,962
[rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10  loss:  7.5778  memory:  4.90GiB(6.20%)  tps: 13,891
"""

Adopt new autoparallel API with meta-init model

Allows reverting a lot of the hacks in the original integration that
were caused by not creating a model obj in the train.py due to passing a
model_fn builder to autop.

Fixes to align with latest autoparallel

Add inductor config knobs for comms optimizations to torchtitan

Make inductor always run compile passes

basically, this is an annoying workaround for debugging iteratively.

1- you run the model, it compiles, but something weird happens
2- you enable some logging or tlparse, rerun. but inductor decides not
to run your pass anymore, its results are cached.

since (2) has confused me horribly on more than one occasion, i just
disable caching for now

Drop hacky llama3_init_fn and use autop init_weights feature

Relying on meta-pytorch/autoparallel#20, this
lets us automatically apply a user's init_weights fn to the autoparallel
model.

Verified this works with

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4`

```
[rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step:  1  loss:  8.1848  memory:  1.09GiB(1.14%)  tps: 77  tflops: 0.01  mfu: 0.00%
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step:  2  loss:  8.1619  memory:  1.15GiB(1.21%)  tps: 48,138  tflops: 3.46  mfu: 0.35
%
[rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step:  3  loss:  8.1140  memory:  1.15GiB(1.21%)  tps: 88,440  tflops: 6.36  mfu: 0.64
%
[rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step:  4  loss:  8.0099  memory:  1.15GiB(1.21%)  tps: 82,626  tflops: 5.94  mfu: 0.60
%
[rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step:  5  loss:  7.8928  memory:  1.15GiB(1.21%)  tps: 81,594  tflops: 5.87  mfu: 0.59
%
[rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step:  6  loss:  7.7758  memory:  1.15GiB(1.21%)  tps: 79,607  tflops: 5.72  mfu: 0.58
%
[rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step:  7  loss:  7.6221  memory:  1.15GiB(1.21%)  tps: 81,448  tflops: 5.86  mfu: 0.59
%
[rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step:  8  loss:  7.5578  memory:  1.15GiB(1.21%)  tps: 79,732  tflops: 5.73  mfu: 0.58
%
[rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step:  9  loss:  7.3851  memory:  1.15GiB(1.21%)  tps: 85,655  tflops: 6.16  mfu: 0.62
%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10  loss:  7.3361  memory:  1.15GiB(1.21%)  tps: 81,855  tflops: 5.89  mfu: 0.60
%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete
```

fix lint
IvanKobzarev added a commit to pytorch/torchtitan that referenced this pull request Jul 25, 2025
IvanKobzarev added a commit to pytorch/torchtitan that referenced this pull request Jul 25, 2025