Enable ability to control whether or not to quantize the bias #14549

Merged 2 commits on Feb 3, 2023
10 changes: 9 additions & 1 deletion onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -97,6 +97,13 @@ def __init__(
False if "AddQDQPairToWeight" not in extra_options else extra_options["AddQDQPairToWeight"]
)

# Some scenarios do not need the bias quantized. In Quantization Aware Training (QAT), for example,
# all model parameters are expected to stay in floating-point format. Weights and activations can
# always carry full QDQ pairs (via AddQDQPairToWeight), so a FakeQuant operator can stand in for them.
# A bias in a quantized model is stored as int32 and only ever feeds a DQ node, with no Q node before it,
# so FakeQuant cannot be used for it.
self.quantize_bias = True if "QuantizeBias" not in extra_options else extra_options["QuantizeBias"]

# The default behavior is that multiple nodes can share a QDQ pair as their inputs.
# In TRT, QDQ pair can’t be shared between nodes, so it will create dedicated QDQ pairs for each node.
self.dedicated_qdq_pair = (
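
The comment above notes that a quantized bias is stored as int32 and is followed only by a DQ node. A minimal sketch of that representation, assuming the common convention that the bias scale is the product of the input and weight scales with a zero point of 0 (the function names here are hypothetical and not part of the quantizer):

```python
def quantize_bias_int32(bias, input_scale, weight_scale):
    """Quantize float biases to int32, assuming scale = input_scale * weight_scale."""
    bias_scale = input_scale * weight_scale
    int32_min, int32_max = -(2**31), 2**31 - 1
    # Round to the nearest integer and clamp to the int32 range.
    quantized = [max(int32_min, min(int32_max, round(b / bias_scale))) for b in bias]
    return quantized, bias_scale


def dequantize(quantized, scale):
    """What the lone DQ node computes at inference time (zero point assumed 0)."""
    return [q * scale for q in quantized]


bias = [0.5, -1.25, 0.0]
q_bias, bias_scale = quantize_bias_int32(bias, input_scale=0.5, weight_scale=0.25)
restored = dequantize(q_bias, bias_scale)
```

Because the stored initializer is already an integer tensor, there is no float-to-int Q step left in the graph for a fake-quantization operator to replace, which is the motivation the comment gives for skipping bias quantization in QAT.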
@@ -211,7 +218,8 @@ def quantize_model(self):

self._quantize_normal_tensors()
self._quantize_sharing_param_tensors()
- self._quantize_bias_tensors()
+ if self.quantize_bias:
+     self._quantize_bias_tensors()
self.remove_nodes()
if not self.add_qdq_pair_to_weight:
self.model.clean_initializers()
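
The effect of this hunk can be sketched with a small stand-in class. `SketchQuantizer` and its step names are hypothetical and only mirror the control flow shown in the diff; the real `QDQQuantizer` does considerably more work in each step.

```python
class SketchQuantizer:
    """Hypothetical stand-in that records which quantization steps run."""

    def __init__(self, extra_options=None):
        extra_options = extra_options or {}
        # Same default as the diff: bias quantization stays on unless disabled.
        self.quantize_bias = extra_options.get("QuantizeBias", True)
        self.steps = []

    def quantize_model(self):
        self.steps.append("normal_tensors")
        self.steps.append("sharing_param_tensors")
        if self.quantize_bias:
            # Skipped entirely when QuantizeBias is False.
            self.steps.append("bias_tensors")
        self.steps.append("remove_nodes")
        return self.steps


default_steps = SketchQuantizer().quantize_model()
qat_steps = SketchQuantizer({"QuantizeBias": False}).quantize_model()
```

With the default options the bias step runs as before, so existing callers see no change in behavior; only an explicit `QuantizeBias: False` removes it.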
5 changes: 5 additions & 0 deletions onnxruntime/python/tools/quantization/quantize.py
@@ -139,6 +139,11 @@ def __init__(
Default is 0.01. Constant smoothing factor to use when computing the moving average of the
minimum and maximum values. Effective only when the calibration method selected is MinMax and
when CalibMovingAverage is set to True.
QuantizeBias = True/False :
Default is True, which quantizes floating-point biases and inserts only a DequantizeLinear node
for each of them. If False, biases remain floating-point and no quantization nodes are
inserted for them.
This extra option is only effective when quant_format is QuantFormat.QDQ.
execution_provider : An enum that indicates the Execution Provider, such as CPU, TRT, NNAPI, SNE, etc.
Raises:
ValueError: Raise ValueError if execution provider is unknown
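
A caller could pass the new option roughly as follows. This is a sketch, not taken from the PR: the helper names `build_qat_extra_options` and `quantize_for_qat` are hypothetical, and pairing `QuantizeBias` with `AddQDQPairToWeight` follows the QAT scenario described in the code comment rather than any stated requirement.

```python
def build_qat_extra_options():
    """Extra options for a QAT-style export: QDQ pairs on weights, float biases."""
    return {
        "AddQDQPairToWeight": True,  # keep weights in float with an explicit Q/DQ pair
        "QuantizeBias": False,       # leave biases in floating point
    }


def quantize_for_qat(model_input, model_output, calibration_data_reader):
    """Run static QDQ quantization without quantizing biases."""
    # Imported lazily so this sketch can be loaded without onnxruntime installed.
    from onnxruntime.quantization import QuantFormat, quantize_static

    quantize_static(
        model_input,
        model_output,
        calibration_data_reader,
        quant_format=QuantFormat.QDQ,  # QuantizeBias only takes effect in QDQ format
        extra_options=build_qat_extra_options(),
    )
```

The paths and the calibration data reader are supplied by the caller; the option has no effect if `quant_format` is anything other than `QuantFormat.QDQ`.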