Add float8 FakeQuantizeConfig and FakeQuantizer #2735
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2735
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 2 Pending (as of commit 6b85230 with merge base a1a9632).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
(force-pushed from df874d5 to 481ac90)
(force-pushed from 73bca60 to 7460a2d)
```
        dtype=base_config.weight_dtype,
        granularity=weight_granularity,
    )
elif isinstance(base_config, Float8ActivationInt4WeightConfig):
```
@jerryzh168 can you confirm these config settings?
```
# TODO: don't register as custom op?
@_register_custom_op(quant_lib, False)
def _dequantize_affine_float8(
```
@jerryzh168 I'm seeing this warning. Maybe should also skip registering this custom op?
```
/home/andrewor/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/autograd/graph.py:824:
UserWarning: torchao::dequantize_affine_float8: an autograd kernel was not registered to the
Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior.
This behavior is deprecated and will be removed in a future version of PyTorch.
If your operator is differentiable, please ensure you have registered an autograd kernel to
the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd).
If your operator is not differentiable, or to squash this warning and use the previous behavior,
please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.
(Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
```
oh, we didn't know this would be a problem, we can do

ao/torchao/quantization/quant_primitives.py, line 361 (at commit 1dca638):
```
@register_custom_op
```
ok, will relax in a separate PR
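For reference, a minimal sketch of the workaround the warning itself suggests (registering an Autograd fallthrough for the existing op); this is not necessarily the fix that will land, and the `torch.library.Library` handle below is a local stand-in for torchao's own `quant_lib`:

```
import torch

# Register a fallthrough on the Autograd key for the already-registered custom op.
# This restores the previous (non-differentiable) behavior and squashes the warning,
# as suggested by the UserWarning above. Namespace and op name are taken from the warning.
_lib = torch.library.Library("torchao", "IMPL")  # stand-in handle, not torchao's quant_lib
_lib.impl(
    "dequantize_affine_float8",
    torch.library.fallthrough_kernel,
    "Autograd",
)
```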
```
    )
else:
    # targeting tinygemm kernel
    assert base_config.VERSION == 1
```
just support version 2, to minimize complexity?
Can the test plan include training a real model and verifying that the loss converges?
(force-pushed from 7f19d27 to 1162ac3)
Yes, this is in progress.
(force-pushed from 3129101 to 204e99b)
test/quantization/test_qat.py (outdated)
```
    _get_qmin_qmax,
)
from torchao.quantization.quant_api import (
    Float8ActivationInt4WeightConfig,
```
nit: we just renamed this one
test/quantization/test_qat.py (outdated)
```
        sqnr = compute_error(out, out_expected)
        self.assertGreater(sqnr, 16)

    @parameterized.expand([(PerRow(),), (PerTensor(),)])
```
nit: why not use
```
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
```
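For example, a minimal sketch of how the suggested decorator is usually wired up with `instantiate_parametrized_tests` (the class name and test body here are placeholders, not the actual test):

```
from torch.testing._internal import common_utils
from torchao.quantization.granularity import PerRow, PerTensor


class TestFloat8QAT(common_utils.TestCase):  # placeholder class name
    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
    def test_quantize_api_fp8_fp8(self, granularity):
        # Placeholder body; the real test builds a model and compares against PTQ.
        self.assertIn(type(granularity), (PerTensor, PerRow))


# Generates the parametrized variants as real test methods on the class.
common_utils.instantiate_parametrized_tests(TestFloat8QAT)

if __name__ == "__main__":
    common_utils.run_tests()
```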
test/quantization/test_qat.py (outdated)
| if "fbgemm-gpu-genai" in str(e): | ||
| self.skipTest("fbgemm-gpu-genai not available") |
nit: we can skip the test when fbgemm-gpu-genai is not installed:
```
@unittest.skipIf(
```
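A minimal sketch of that skip, assuming the installed package exposes the `fbgemm_gpu` module (the module name, availability check, and test class name are assumptions; torchao may already provide its own availability helper):

```
import importlib.util
import unittest

# Skip at collection time instead of catching the ImportError inside the test body.
_HAS_FBGEMM_GPU_GENAI = importlib.util.find_spec("fbgemm_gpu") is not None


class TestQATFloat8Int4(unittest.TestCase):  # placeholder class name
    @unittest.skipIf(not _HAS_FBGEMM_GPU_GENAI, "fbgemm-gpu-genai not available")
    def test_quantize_api_fp8_int4(self):
        ...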
test/quantization/test_qat.py (outdated)
```
        try:
            quantize_(m, QATConfig(base_config, step="prepare"))
            quantize_(m, QATConfig(base_config, step="convert"))
            m(*example_inputs)
```
Should this happen between prepare and convert as well?
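Something like the following, reusing the names from the test above (`m`, `base_config`, and `example_inputs` come from the surrounding test); a sketch of the suggestion, not the final test code:

```
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig

quantize_(m, QATConfig(base_config, step="prepare"))
m(*example_inputs)  # also exercise the fake-quantized model between prepare and convert
quantize_(m, QATConfig(base_config, step="convert"))
m(*example_inputs)  # and the converted model, as in the original test
```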
```
        granularity=weight_granularity,
    )
elif isinstance(base_config, Float8DynamicActivationInt4WeightConfig):
    act_config = Float8FakeQuantizeConfig(
```
One thing we should pay extra attention to here is whether the simulation also works for the int4 preshuffled tensor; I think we need some numerics testing to make sure.
added an fp8-int4 numerics test
```
        max_abs = torch.clamp(max_abs, min=hp_value_lb, max=hp_value_ub)
        scale = max_abs / quant_max
    else:
        # rowwise
```
This is not necessarily rowwise, I think; I believe this branch covers all granularities, and the `len(block_size) == 0` case is more of a special case for tensorwise quant. I'm not sure where it comes from or whether it's needed; we could try to trace it and see if it can be removed as well, to reduce complexity.
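For reference, an illustrative sketch of the two scale paths being discussed (amax-based float8 scaling; the function name, the omission of the `hp_value_lb`/`hp_value_ub` clamp, and the hard-coded e4m3 max of 448 are simplifications, not the PR's code):

```
import torch

def _illustrative_float8_scale(
    x: torch.Tensor,
    per_row: bool,
    quant_max: float = 448.0,  # finite max of torch.float8_e4m3fn
) -> torch.Tensor:
    """Compute an amax-based float8 scale, either tensorwise or rowwise."""
    if per_row:
        # One scale per row: reduce over the last dimension only.
        max_abs = x.abs().amax(dim=-1, keepdim=True)
    else:
        # A single tensorwise scale: reduce over all elements.
        max_abs = x.abs().amax()
    return max_abs.to(torch.float32) / quant_max
```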
(force-pushed from 204e99b to d669f71)
(force-pushed from f002586 to e1823c6)
```
        Float8FakeQuantizeConfig(granularity=PerToken())

    @parametrize("granularity", [PerTensor(), PerRow()])
    def test_float8_fake_quantize(self, granularity: Granularity):
```
Can you add the same test for fp8_int4?
Added an SQNR comparison against PTQ fp8_int4.
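Roughly along these lines, as a sketch of such a numerics check (`m`, `base_config`, and `example_inputs` stand in for the test's model, fp8-int4 PTQ config, and inputs; the SQNR threshold of 16 mirrors the existing test's bound and is not the PR's exact code):

```
import copy

from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.quantization.utils import compute_error

# Compare the fake-quantized (QAT prepare) model against the actually quantized (PTQ) model.
m_qat = copy.deepcopy(m)
quantize_(m_qat, QATConfig(base_config, step="prepare"))
m_ptq = copy.deepcopy(m)
quantize_(m_ptq, base_config)

out_qat = m_qat(*example_inputs)
out_ptq = m_ptq(*example_inputs)
sqnr = compute_error(out_qat, out_ptq)
assert sqnr > 16, f"QAT vs PTQ SQNR too low: {sqnr}"
```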
(force-pushed from e1823c6 to bd0bd9d)
**Summary:** This commit adds a QAT path for float8, using the
same primitives as `torchao.quantization.Float8Tensor`, targeting
the following PTQ configs:
- `Float8DynamicActivationFloat8WeightConfig`
- `Float8DynamicActivationInt4WeightConfig`
Usage:
```
import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    quantize_,
)
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import QATConfig

base_config = Float8DynamicActivationFloat8WeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    granularity=PerRow(),
)
quantize_(model, QATConfig(base_config, step="prepare"))
quantize_(model, QATConfig(base_config, step="convert"))
```
OR
```
import torch

from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import (
    Float8FakeQuantizeConfig,
    QATConfig,
)

dtype = torch.float8_e4m3fn
granularity = PerRow()
quantize_(model, QATConfig(
activation_config=Float8FakeQuantizeConfig(dtype, granularity),
weight_config=Float8FakeQuantizeConfig(dtype, granularity),
step="prepare",
))
# convert (same as above, not shown)
```
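In both variants the usual QAT flow applies: `step="prepare"` inserts high-precision fake quantization so fine-tuning sees the float8 quantization error, and `step="convert"` then swaps the fake-quantized modules for the real quantized model.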
**Test Plan:**
```
python test/quantization/test_qat.py -k test_float8_fake_quantize_config
python test/quantization/test_qat.py -k test_float8_fake_quantize
python test/quantization/test_qat.py -k test_quantize_api_fp8_fp8
python test/quantization/test_qat.py -k test_quantize_api_fp8_int4
```
(force-pushed from bd0bd9d to 6b85230)
**Test Plan:**
Identical outputs between normal bf16 and QAT fine-tuning for both fp8-fp8 and fp8-int4, reproduced on Llama3.1 using this unsloth notebook. Loss curves also overlap almost exactly (not shown).