**Summary:** #2253 added a step
in `quantize_affine_float8` to expand the scales for blockwise
quantization. The purpose of this step is to make the scales
always broadcastable with the input tensor. However, this is
unnecessary for rowwise quantization, which already has
broadcastable shapes, e.g.
```
scale = [32, 1]
input = [32, 16]
```
Today, we `repeat_interleave` the above scale tensor to
expand it to `[32, 16]`, which adds non-trivial memory and
latency overhead. This commit adds a fast path that skips
the expansion step when we detect rowwise quantization.
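For illustration, here is a minimal sketch of what such a fast path could look like (the helper name is borrowed from the test name in the Test Plan below; the actual signature and logic in torchao may differ):
```
import torch

def maybe_expand_scale_to_tensor_shape(
    scale: torch.Tensor,
    target_shape: torch.Size,
) -> torch.Tensor:
    # Fast path (rowwise): every scale dim is either 1 or already matches
    # the target, so elementwise ops broadcast for free and no copy is made
    if scale.ndim == len(target_shape) and all(
        s == 1 or s == t for s, t in zip(scale.shape, target_shape)
    ):
        return scale
    # Slow path (blockwise): repeat each scale entry along every mismatched
    # dim until the shape matches the target, e.g. [4, 2] -> [32, 16]
    for dim, (s, t) in enumerate(zip(scale.shape, target_shape)):
        if s != t:
            scale = scale.repeat_interleave(t // s, dim=dim)
    return scale
```
With the rowwise shapes above, this sketch returns the `[32, 1]` scale untouched instead of materializing a `[32, 16]` copy.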
**Test Plan:**
```
python test/quantization/test_quant_primitives.py -k test_maybe_expand_scale_to_tensor_shape
```
Also compared fine-tuning Qwen3-1.7B with fp8-fp8 QAT using
batch size 32 on a single H100 GPU:
- Before: 25.34 GB peak memory, 3047.25 tok/s
- After: 22.53 GB peak memory, 3358.49 tok/s
- This PR uses 11.1% less memory and is 10.2% faster