
Add cachemask variant for fake_quantize_affine #500

Merged
andrewor14 merged 1 commit into main from fake_quantize_affine_cachemask on Jul 17, 2024

Conversation

andrewor14 (Contributor)

Summary: In QAT, we often wish to filter out the gradients corresponding to values outside the expected quantization range, for example:

# Forward
q = _quantize_affine_no_dtype_cast(...)
dq = _dequantize_affine_no_dtype_check(...)
mask = torch.logical_and((q >= quant_min), (q <= quant_max))

# Backward
grad = grad * mask

The existing `fake_quantize_affine` returns only the dequantized values, so callers do not have access to this mask. This commit adds a variant of this op that returns both the dequantized values and the mask, similar to `fake_quantize_per_tensor_affine_cachemask` in PyTorch core.
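
For illustration (not part of this PR), one way a QAT caller might consume the returned mask is in a custom autograd function; this sketch assumes the new op takes the same arguments as `fake_quantize_affine` and returns a `(dequantized_values, mask)` tuple:

```
# Illustrative sketch only: use the cachemask variant to zero out gradients
# for values that fall outside [quant_min, quant_max]. Assumes the op mirrors
# fake_quantize_affine's argument order and returns (dequantized_values, mask).
import torch
from torchao.quantization.quant_primitives import fake_quantize_affine_cachemask

class _FakeQuantizeSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, block_size, scale, zero_point, quant_dtype, quant_min, quant_max):
        dq, mask = fake_quantize_affine_cachemask(
            x, block_size, scale, zero_point, quant_dtype, quant_min, quant_max,
        )
        ctx.save_for_backward(mask)
        return dq

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # Straight-through estimator: pass gradients only where the quantized
        # value stayed inside the representable range.
        return grad_output * mask, None, None, None, None, None, None
```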

Test Plan:
python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine_cachemask


pytorch-bot bot commented Jul 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/500

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d70f92c with merge base aef7e09:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jul 12, 2024.
@jerryzh168 (Contributor) commented on Jul 12, 2024 on these docstring lines:

    )

    Args:
        input (torch.Tensor): original float32, float16 or bfloat16 Tensor

we can dedup the comments by: https://github.com/pytorch/pytorch/blob/06ebf87a1eca6c345f5e3e39b63c2ef487695043/torch/ao/quantization/observer.py#L582

e.g.

:func:`~torchao.quantization.quant_primitives.fake_quantize_affine`

?

@andrewor14 force-pushed the fake_quantize_affine_cachemask branch from 4e4dd12 to ab7f401 on July 15, 2024 20:08
The new docstring under review:

    General fake quantize op for quantization-aware training (QAT).
    This is equivalent to calling `quantize_affine` + `dequantize_affine`
    but without the dtype casts.
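
(Aside, not part of the review thread: a self-contained sketch of what that equivalence means for a single per-tensor scale and zero point. This is a reference illustration, not torchao's implementation.)

```
# Reference illustration only: quantize + dequantize with no intermediate
# integer dtype cast, per-tensor for simplicity. Not torchao's implementation.
import torch

def fake_quantize_reference(x, scale, zero_point, quant_min, quant_max):
    # Quantize: scale, shift, round, clamp -- but stay in floating point.
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    # Dequantize back to the original scale.
    return (q - zero_point) * scale

x = torch.randn(4, 8)
out = fake_quantize_reference(x, scale=0.05, zero_point=0, quant_min=-128, quant_max=127)
```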

Contributor review comment on the docstring above:

please add a section for Args and link to fake_quantize_affine

Further down, on the return annotation:

    outlier mask for intermediate quantized values
    )

    Please refer to :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`
Contributor review comment:

oh, I think we can move this to before Returns

@jerryzh168 (Contributor) left a review:

LG, had a comment on updating the docstring a bit
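
Taken together, the review suggestions point to a docstring shaped roughly like the sketch below; the signature and exact wording here are illustrative assumptions, not the merged code:

```
# Illustrative sketch only: Args deduplicated via a Sphinx cross-reference,
# placed before Returns, as suggested in review. The signature is an assumption.
def fake_quantize_affine_cachemask(input, block_size, scale, zero_point,
                                   quant_dtype, quant_min=None, quant_max=None):
    """General fake quantize op for quantization-aware training (QAT).

    This is equivalent to calling `quantize_affine` + `dequantize_affine`
    but without the dtype casts.

    Args:
        Please refer to :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`.

    Returns:
        A tuple of (fake-quantized value, outlier mask for intermediate quantized values).
    """
    raise NotImplementedError("docstring illustration only")
```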

Commit message:

Summary: In QAT, we often wish to filter out the gradients
corresponding to values outside the expected quantization
range, for example:

```
q = _quantize_affine_no_dtype_cast(...)
dq = _dequantize_affine_no_dtype_check(...)
mask = torch.logical_and((q >= quant_min), (q <= quant_max))

grad = grad * mask
```

The existing `fake_quantize_affine` returns the dequantized
values only, so callers do not have access to this mask.
This commit adds a variant of this op that returns both
the dequantized values and the mask, similar to
`fake_quantize_per_tensor_affine_cachemask` in core.

Test Plan:
python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine_cachemask
@andrewor14 force-pushed the fake_quantize_affine_cachemask branch from ab7f401 to d70f92c on July 16, 2024 23:09
@andrewor14 merged commit 03c4553 into main on Jul 17, 2024
13 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
Labels: CLA Signed

3 participants