forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for int4 weight-only QAT (pytorch#383)
* Add cachemask variant for fake_quantize_affine Summary: In QAT, we often wish to filter out the gradients corresponding to values outside the expected quantization range, for example: ``` q = _quantize_affine_no_dtype_cast(...) dq = _dequantize_affine_no_dtype_check(...) mask = torch.logical_and((q >= quant_min), (q <= quant_max)) grad = grad * mask ``` The existing `fake_quantize_affine` returns the dequantized values only, so callers do not have access to this mask. This commit adds the variant to this op that returns both the dequantized values and the mask, similar to `fake_quantize_per_tensor_affine_cachemask` in core. Test Plan: python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine_cachemask * Add support for int4 weight-only QAT Summary: This commit adds support for int4 weight-only QAT, which simulates the numerics of the existing Int4WeightOnlyQuantizer. The main motivation for this is to provide an end-to-end path for running QAT and lowering to the efficient int4 tinygemm cuda kernel. To enable this, we have to add new fake quantization primitives to match the numerics of the tinygemm kernel, and this required refactoring existing quant primitives to skip dtype casting. Test Plan: python test/quantization/test_qat.py -k test_qat_4w_linear Reviewers: jerryzh168, msaroufim Subscribers: jerryzh168, msaroufim, HDCharles, supriyar
- Loading branch information
1 parent
aaf209d
commit 5fe9967
Showing
3 changed files
with
477 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.