
Conversation

@andrewor14 (Contributor) commented Sep 17, 2025

**Summary:** Add support for passing scales and zero points learned during QAT range learning to the PTQ base config. Currently, only the following configs support this feature:

```
IntxWeightOnlyConfig
Int8DynamicActivationInt4WeightConfig
Int8DynamicActivationIntxWeightConfig
```

During the convert phase, QAT will detect whether range learning was used during training and pass the learned scales and zero points as custom qparams to the quantized tensor subclass, so that PTQ produces numerics consistent with training.

Fixes part of #2271.
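
For context, a minimal sketch of the end-to-end flow this enables, assuming torchao's `QATConfig` prepare/convert API and an `IntxFakeQuantizeConfig` with the range-learning options shown in the review diff below (treat the exact names and signatures as assumptions, not this PR's literal code):

```python
import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig, IntxFakeQuantizeConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 256))
base_config = Int8DynamicActivationInt4WeightConfig()

# Prepare: fake quantize the weights with learnable scales and zero points.
weight_config = IntxFakeQuantizeConfig(
    torch.int4,
    group_size=32,
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
quantize_(model, QATConfig(weight_config=weight_config, step="prepare"))

# ... initialize the fake quantizers and train; the scales and zero points
# are now parameters updated by gradient descent ...

# Convert: with this change, the learned qparams are passed through to the
# PTQ base config as custom qparams instead of being re-derived from the
# trained weights.
quantize_(model, QATConfig(base_config, step="convert"))
```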

**Test Plan:**

```
python test/quantization/test_qat.py -k test_range_learning_convert_pass_qparams
```

@pytorch-bot bot commented Sep 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3022

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 75c6c8d with merge base 9e5059e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@andrewor14 marked this pull request as draft September 17, 2025 17:25
@meta-cla bot added the CLA Signed label Sep 17, 2025
@andrewor14 force-pushed the try-range-learning-pass-qparams branch from 2f0b504 to f061ae7 September 17, 2025 17:26
@andrewor14 added the topic: improvement label Sep 17, 2025
@facebook-github-bot (Contributor) commented:

@andrewor14 has imported this pull request. If you are a Meta employee, you can view this in D82654987.

@andrewor14 force-pushed the try-range-learning-pass-qparams branch from f061ae7 to c6b0fe7 September 17, 2025 17:34
```python
    **kwargs,
)
if scale_dtype is not None and scale_dtype != weight.dtype:
    _adjust_scale_dtype_in_intx_unpacked_tensor(
```
Contributor: Presumably you don't want this to run if you passed custom scales?

Contributor Author (@andrewor14): I looked into this a bit and I think it's actually fine to run this. Basically, we use the custom scale first and then cast it to the configured scale dtype (otherwise `custom_scale` would have to leak into this file, which I hope to avoid).

Contributor: I think you're right.
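
A tiny sketch of the ordering just described, with hypothetical names, purely to illustrate the order of operations:

```python
import torch

def resolve_scale(computed_scale, custom_scale=None, scale_dtype=None):
    # Prefer the QAT-learned scale when one was passed in...
    scale = custom_scale if custom_scale is not None else computed_scale
    # ...then apply the configured dtype cast uniformly, so the custom
    # scale never has to leak into the dtype-adjustment helper.
    if scale_dtype is not None and scale.dtype != scale_dtype:
        scale = scale.to(scale_dtype)
    return scale
```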

@andrewor14 force-pushed the try-range-learning-pass-qparams branch from c6b0fe7 to adb966c September 17, 2025 18:44

@andrewor14 force-pushed the try-range-learning-pass-qparams branch from adb966c to 9d0aadd September 17, 2025 20:20
@andrewor14 changed the title from "[draft] Pass QAT learned qparams in convert" to "Pass QAT learned qparams in convert" Sep 17, 2025

@andrewor14 marked this pull request as ready for review September 17, 2025 20:52
```python
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
```
Contributor: We should not use float32 zeros if we want to lower to XNNPACK. We should convert the int32 zeros to int8 when passing in `custom_zero_point`.


```diff
-def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
+def _int8_dynamic_activation_intx_weight_quantize_tensor(
+    weight, bias, config, **kwargs
```
Contributor: Spelling out the kwargs might be better?

Contributor Author (@andrewor14): Like `tensor_subclass_kwargs`?

Contributor (@jerryzh168), Sep 17, 2025: No, I mean use `custom_scale` and `custom_zero_point`, since `**kwargs` is more complicated than this.

Contributor Author (@andrewor14): Done.
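
Presumably the resolved signature looks roughly like the following (a sketch based on the thread above; body elided):

```python
def _int8_dynamic_activation_intx_weight_quantize_tensor(
    weight,
    bias,
    config,
    custom_scale=None,
    custom_zero_point=None,
):
    ...
```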

@andrewor14 force-pushed the try-range-learning-pass-qparams branch 3 times, most recently from aae7eee to ecd9349 September 18, 2025 20:30

Comment on lines 865 to 866:

```python
if custom_zero_point.dtype == torch.int32:
    custom_zero_point = custom_zero_point.to(torch.int8)
```
Contributor: Should we do an assert instead?

Contributor Author (@andrewor14): This is to make the flow work by default without the user having to specify the zero point dtype (QAT's default is int32), so adding an assert would break that.
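
A hedged sketch of the tradeoff being discussed here; the assert variant is hypothetical, shown only for contrast:

```python
import torch

custom_zero_point = torch.zeros(8, dtype=torch.int32)  # QAT's default dtype

# Option taken in this PR: silently narrow so the default QAT flow works.
if custom_zero_point.dtype == torch.int32:
    custom_zero_point = custom_zero_point.to(torch.int8)

# Rejected alternative: an assert would force every caller to pre-cast,
# breaking convert for anyone relying on QAT's int32 default.
# assert custom_zero_point.dtype == torch.int8, "expected int8 zero points"
```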


```python
if config.version == 2:
    if config.intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8:
        if custom_zero_point.dtype == torch.int32:
```
Contributor: Same here.

@andrewor14 force-pushed the try-range-learning-pass-qparams branch from ecd9349 to 75c6c8d September 19, 2025 17:10

@andrewor14 merged commit ae12e42 into main Sep 19, 2025 (19 of 20 checks passed)