Pass QAT learned qparams in convert #3022
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3022. Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 75c6c8d with merge base 9e5059e. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed from 2f0b504 to f061ae7.
@andrewor14 has imported this pull request. If you are a Meta employee, you can view this in D82654987.
Force-pushed from f061ae7 to c6b0fe7.
Review thread on torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py (outdated, resolved):
```python
        **kwargs,
    )
    if scale_dtype is not None and scale_dtype != weight.dtype:
        _adjust_scale_dtype_in_intx_unpacked_tensor(
```
Presumably you don't want this to run if you passed custom scales?
I looked into this a bit and I think it's actually fine to run this. Basically we use the custom scale first and then cast it to the configured scale dtype (otherwise custom_scale would have to leak into this file, which I hope to avoid).
I think you're right
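For reference, a minimal sketch of the ordering discussed in this thread: the custom scale (if provided) is used to build the tensor first, and the scale-dtype adjustment runs afterwards either way. Apart from the `scale_dtype` / `custom_scale` names taken from the snippet above, the quantization math here is a placeholder, not the actual implementation.

```python
import torch

def _build_unpacked_tensor(weight, scale_dtype=None, custom_scale=None):
    # Illustrative placeholder: take the caller's scale if provided, otherwise
    # derive one from the weight (not the real quantization algorithm).
    scale = custom_scale if custom_scale is not None else weight.abs().amax() / 7
    qdata = torch.clamp(torch.round(weight / scale), -8, 7).to(torch.int8)

    # The dtype adjustment runs last in both cases, so a custom scale is simply
    # cast to the configured scale_dtype; custom_scale never has to leak into
    # the tensor subclass file itself.
    if scale_dtype is not None and scale_dtype != weight.dtype:
        scale = scale.to(scale_dtype)
    return qdata, scale
```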
Force-pushed from c6b0fe7 to adb966c.
Force-pushed from adb966c to 9d0aadd.
Review thread on test/quantization/test_qat.py (outdated):
```python
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
```
We should not use float32 zeros if we want to lower to XNNPACK.
We should convert the int32 zeros to int8 when passing in custom_zero_point.
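A small sketch of that conversion on the caller side; the helper name is hypothetical, and only the int32-to-int8 cast is the point being made:

```python
import torch

def learned_qparams_for_convert(scale: torch.Tensor, zero_point: torch.Tensor):
    # QAT range learning keeps zero points in int32 by default; XNNPACK
    # lowering expects int8, so cast before passing them as custom_zero_point.
    if zero_point.dtype == torch.int32:
        zero_point = zero_point.to(torch.int8)
    return scale.detach(), zero_point.detach()
```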
Review thread on torchao/quantization/quant_api.py (outdated):
```diff
-def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
+def _int8_dynamic_activation_intx_weight_quantize_tensor(
+    weight, bias, config, **kwargs
```
Spelling out the kwargs might be better?
like tensor_subclass_kwargs?
No, I mean use custom_scale and custom_zero_point explicitly, since **kwargs is more complicated than this needs to be.
done
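For clarity, a sketch of the explicit signature the thread converged on (body elided; the final version lives in the diff):

```python
def _int8_dynamic_activation_intx_weight_quantize_tensor(
    weight,
    bias,
    config,
    custom_scale=None,
    custom_zero_point=None,
):
    # Spelling out the two qparams keeps call sites self-documenting, compared
    # to threading an opaque **kwargs through the quantize path.
    ...
```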
Force-pushed from aae7eee to ecd9349.
Review thread on torchao/quantization/quant_api.py (outdated):
```python
if custom_zero_point.dtype == torch.int32:
    custom_zero_point = custom_zero_point.to(torch.int8)
```
Should we do an assert instead?
This is to make the flow work by default without the user having to specify the zero point dtype (QAT's default is int32), so adding an assert would break that.
Review thread on torchao/quantization/quant_api.py (outdated):
```python
if config.version == 2:
    if config.intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8:
        if custom_zero_point.dtype == torch.int32:
```
same here
**Summary:** Add support for passing scales and zero points learned during QAT range learning to the PTQ base config. Currently only the following configs support this feature:

```
IntxWeightOnlyConfig
Int8DynamicActivationInt4WeightConfig
Int8DynamicActivationIntxWeightConfig
```

During the convert phase, QAT will detect whether range learning was used during training and, if so, pass the learned scales and zero points as custom qparams to the quantized tensor subclass, so PTQ will produce more consistent numerics.

Fixes part of #2271.

**Test Plan:**

```
python test/quantization/test_qat.py -k test_range_learning_convert_pass_qparams
```
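Roughly, the intended user-facing flow looks like the sketch below. The config and step names follow torchao's QAT APIs as I understand them, but treat the exact constructor arguments (and the omitted fake-quantizer initialization step for range learning) as assumptions; the authoritative version is `test_range_learning_convert_pass_qparams`.

```python
import torch
from torchao.quantization import quantize_, IntxWeightOnlyConfig
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import QATConfig, IntxFakeQuantizeConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# Weight fake quantizer with range learning, mirroring the test's settings.
weight_fq = IntxFakeQuantizeConfig(
    torch.int4,
    group_size=32,
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
base_config = IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32))

# Prepare: insert trainable fake quantizers, then train so the scales and
# zero points are learned.
quantize_(model, QATConfig(weight_config=weight_fq, step="prepare"))
# ... training loop ...

# Convert: range learning is detected and the learned scales / zero points are
# forwarded to the base config as custom_scale / custom_zero_point.
quantize_(model, QATConfig(base_config, step="convert"))
```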
Force-pushed from ecd9349 to 75c6c8d.