-
Notifications
You must be signed in to change notification settings - Fork 363
Commit 757f6fd
committed
Update on "Support NVFP4 dynamic per tensor scale"
**Summary:** This commit adds an option for the existing
`NVFP4InferenceConfig` to dynamically compute an appropriate
fp32 per tensor scale to support the two level scaling
according to the NVFP4 specification:
https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.
While two level scaling is supported in `NVFP4Tensor`, today
there is no config API for users to call this. The existing
`NVFP4InferenceConfig` only supports single level scaling
because including an explicit `per_tensor_scale` field would
make serialization tricky.
In the future, we should add an end-to-end calibration flow
so users can compute an appropriate per tensor scale for the
activations first, and then pass this to `NVFP4Tensor` as a
static scale, similar to the proposal in #2572.
**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
Also did a quick benchmark before and after:
```
import copy
import time
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig
m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")
start = time.time()
for _ in range(1000):
m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")
start = time.time()
for _ in range(1000):
m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```
On a single B200:
```
No per_tensor_scale = 1.2855589389801025 seconds
With per_tensor_scale = 1.3009123802185059 seconds
```
[ghstack-poisoned]1 parent 6fc1dab commit 757f6fdCopy full SHA for 757f6fd
File tree
Expand file treeCollapse file tree
0 file changed
+0
-0
lines changedOpen diff view settings
Filter options
Expand file treeCollapse file tree
0 file changed
+0
-0
lines changedOpen diff view settings
0 commit comments