Skip expanding scales for rowwise fp8 quantize #2950
Conversation
🔗 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2950. ✅ No failures as of commit 4cf5c90 with merge base 4872c4f.
```python
# For rowwise quantization, just return the scale as is
if scale.shape[:-1] == target_shape[:-1] and scale.shape[-1] == 1:
    return scale
```
you could probably do something fun, like

```python
def is_trivial_expandable(scale, target_shape):
    return all(a == b or a == 1 for a, b in zip(scale.shape, target_shape))
```

**Summary:** #2253 added a step in `quantize_affine_float8` to expand the scales for blockwise quantization. The purpose of this step is to make the scales always broadcastable with the input tensor. However, it is unnecessary for rowwise quantization, whose scales already broadcast against the input, e.g.

```
scale = [32, 1]
input = [32, 16]
```

Today, we `repeat_interleave` the above scale to pad it out to `[32, 16]`, which adds non-trivial memory and latency overhead. This commit adds a fast path that skips the expansion step when rowwise quantization is detected.

**Test Plan:**

```
python test/quantization/test_quant_primitives.py -k test_maybe_expand_scale_to_tensor_shape
```

Also compared fine-tuning Qwen3-1.7B with fp8-fp8 QAT using batch size 32 on a single H100 GPU:
- Before: 25.34 GB peak memory, 3047.25 tok/s
- After: 22.53 GB peak memory, 3358.49 tok/s
- This PR uses 11.1% less memory and is 10.2% faster
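For illustration only, here is a minimal sketch of the fast path described above. The helper name and the simplified blockwise fallback are assumptions for this sketch (the actual torchao implementation may differ); the point is that the rowwise branch returns the scale unchanged instead of materializing a `[32, 16]` copy:

```python
import torch

def _maybe_expand_scale_to_tensor_shape(scale: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
    # Rowwise fast path: a [32, 1] scale already broadcasts against a [32, 16]
    # input, so return it as is instead of materializing a full-size copy.
    if scale.shape[:-1] == target_shape[:-1] and scale.shape[-1] == 1:
        return scale
    # Simplified blockwise fallback: repeat each scale entry along the last dim.
    return scale.repeat_interleave(target_shape[-1] // scale.shape[-1], dim=-1)

x = torch.randn(32, 16)
scale = x.abs().amax(dim=-1, keepdim=True) / 448.0  # rowwise fp8 e4m3 scale, shape [32, 1]
out = _maybe_expand_scale_to_tensor_shape(scale, x.shape)
assert out is scale  # fast path taken: no repeat_interleave, no extra allocation
```

The final assertion is what the fast path buys: on the rowwise branch no new tensor is allocated at all.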
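As an illustrative sanity check (not part of the PR's test plan), the sketch below shows why skipping the expansion is safe for the scaling step: dividing by the broadcastable `[32, 1]` scale yields the same fp8 values as dividing by the fully expanded `[32, 16]` scale. The scale computation here is a simplification of what `quantize_affine_float8` actually does:

```python
import torch

x = torch.randn(32, 16)
rowwise_scale = x.abs().amax(dim=-1, keepdim=True) / 448.0    # shape [32, 1]
expanded_scale = rowwise_scale.repeat_interleave(16, dim=-1)  # shape [32, 16], the old behavior

q_broadcast = (x / rowwise_scale).to(torch.float8_e4m3fn)     # relies on broadcasting, no expansion
q_expanded = (x / expanded_scale).to(torch.float8_e4m3fn)
torch.testing.assert_close(q_broadcast.float(), q_expanded.float())  # identical quantized values
```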