
Conversation


@karthickai karthickai commented Oct 14, 2025

Stacked PRs:


Fixes a numerical instability bug in float8 rowwise quantization that causes gradient norm explosion and training divergence when the model is in bfloat16. The issue only shows up for the combination of a bfloat16 model with the float8 rowwise recipe.

Issue:

  • Gradient norms explode to NaN; only affects bfloat16 models + float8 rowwise quantization
  • Works fine with fp32 models + float8 rowwise quantization

The root cause is that amax_to_scale upcasts amax to float64, which introduces a precision error when the amax comes from a bfloat16 tensor. The rowwise recipe's power-of-2 scaling amplifies that error, leading to NaN gradients.
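
For intuition, here is a small standalone snippet (not code from this PR) showing that an amax computed from bfloat16 data is already coarse, so the later float64 upcast cannot recover the lost precision:

```python
import torch

# Illustration only: bfloat16 keeps just 8 mantissa bits, so the amax of a
# bfloat16 tensor is already rounded before amax_to_scale ever sees it, and
# upcasting that value to float64 afterwards cannot restore the lost bits.
x = torch.randn(4096, dtype=torch.float32)

amax_fp32 = x.abs().max()                                        # reference amax in float32
amax_bf16 = x.to(torch.bfloat16).abs().max().to(torch.float64)   # bf16 amax, then fp64 upcast

print(f"fp32 amax:       {amax_fp32.item():.8f}")
print(f"bf16->fp64 amax: {amax_bf16.item():.8f}")
print(f"relative error:  {abs(amax_bf16.item() - amax_fp32.item()) / amax_fp32.item():.2e}")
```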

Fix: convert bfloat16 tensors to float32 before the amax computation, rather than removing the float64 upcast, so the existing flow is unaffected.
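
A minimal sketch of the idea, assuming a simplified amax/scale pipeline (the function names, EPS value, and power-of-2 rounding below are illustrative, not the actual torchao implementation):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
EPS = 1e-12  # illustrative epsilon, not necessarily the value torchao uses

def tensor_to_amax_sketch(x: torch.Tensor) -> torch.Tensor:
    # The fix: compute amax in float32 when the input is bfloat16, so the
    # downstream scale computation starts from a full-precision amax.
    if x.dtype == torch.bfloat16:
        x = x.to(torch.float32)
    return x.abs().max()

def amax_to_scale_sketch(amax: torch.Tensor, round_to_power_of_2: bool = True) -> torch.Tensor:
    # Existing flow kept unchanged: upcast to float64 for the division,
    # then cast the resulting scale back to float32.
    scale = (FP8_MAX / torch.clamp(amax.to(torch.float64), min=EPS)).to(torch.float32)
    if round_to_power_of_2:
        # The rowwise recipe rounds scales down to a power of two; per the PR
        # description, this step amplifies any error already present in amax.
        scale = torch.exp2(torch.floor(torch.log2(scale)))
    return scale

# Example: quantization scale for a bfloat16 weight row
w_row = torch.randn(1024, dtype=torch.bfloat16)
scale = amax_to_scale_sketch(tensor_to_amax_sketch(w_row))
```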

To test, change this line https://github.com/pytorch/torchtitan/blob/248aca2cac54b4ce5d1cd460dd6ec187a8c276b3/torchtitan/train.py#L153
to `model = self.train_spec.model_cls(model_args).to(torch.bfloat16)` and run this command in torchtitan:

TORCH_TRACE=logs NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.steps=10 --model.converters="quantize.linear.float8" --quantize.linear.float8.recipe-name="rowwise" --compile.enable --activation_checkpoint.mode="selective" --activation_checkpoint.selective_ac_option="op" --metrics.log_freq=1

Before fix:
[screenshot: Screenshot 2025-10-13 at 9 00 30 PM]

After fix:
[screenshot: Screenshot 2025-10-13 at 8 58 52 PM]

Fixes: pytorch/pytorch#150859 (RMS norm causes NaNs when used with torch.compile + float8 with rowwise scales)

stack-info: PR: #3171, branch: karthickai/stack/1

pytorch-bot bot commented Oct 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3171

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 00141f4 with merge base 838dceb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Oct 14, 2025
@karthickai karthickai added the float8 and topic: not user facing labels Oct 14, 2025