Fix gradient explosion in bfloat16+float8 rowwise quantization #3171
Stacked PRs:
Fixes a numerical instability bug in float8 rowwise quantization that causes gradient-norm explosion and training divergence when the model runs in bfloat16 precision. The issue affects the combination of bfloat16 models with the float8 rowwise configuration.
Issue:
`amax_to_scale` upcasts `amax` to fp64, which creates a precision error. The rowwise recipe uses power-of-2 scaling, which amplifies the error and leads to NaN gradients.

Fix: convert bfloat16 tensors to float32 before the amax computation, instead of removing the float64 upcast, so the existing flow is not affected.
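For context, here is a minimal sketch of what the fix amounts to (not the actual torchao code; `rowwise_amax` and `amax_to_rowwise_scale` are hypothetical names): upcast bfloat16 inputs to float32 before the max-abs reduction, while leaving the existing float64 upcast and power-of-2 rounding in the scale computation untouched.

```python
import torch

def rowwise_amax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # The fix: do the max-abs reduction in float32 rather than bfloat16;
    # other dtypes keep the existing flow.
    if x.dtype == torch.bfloat16:
        x = x.to(torch.float32)
    return x.abs().amax(dim=dim, keepdim=True)

def amax_to_rowwise_scale(amax: torch.Tensor, float8_max: float = 448.0) -> torch.Tensor:
    # Existing flow (unchanged): compute the scale in float64, round it down
    # to a power of two per the rowwise recipe, and return it as float32.
    scale = float8_max / torch.clamp(amax.to(torch.float64), min=1e-12)
    scale = torch.exp2(torch.floor(torch.log2(scale)))
    return scale.to(torch.float32)
```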
To test: change this line https://github.com/pytorch/torchtitan/blob/248aca2cac54b4ce5d1cd460dd6ec187a8c276b3/torchtitan/train.py#L153
to
`model = self.train_spec.model_cls(model_args).to(torch.bfloat16)`
and run this command in torchtitan.

Before fix:

After fix:

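As a lighter-weight sanity check than a full torchtitan run, the sketch above can be exercised directly on a bfloat16 tensor to confirm that scales and quantized values stay finite (again purely illustrative, reusing the hypothetical helpers from above):

```python
# Reuses rowwise_amax / amax_to_rowwise_scale from the sketch above.
x = torch.randn(16, 4096, dtype=torch.bfloat16) * 100.0

amax = rowwise_amax(x)               # per-row max-abs, computed in float32
scale = amax_to_rowwise_scale(amax)  # power-of-2 rowwise scale, float32
x_fp8 = (x.to(torch.float32) * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

assert torch.isfinite(scale).all()
assert torch.isfinite(x_fp8.to(torch.float32)).all()
```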
Fixes: pytorch/pytorch#150859