-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for QLinearMul ONNX op #8773
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3279,6 +3279,40 @@ def get_scalar(x, dtype="float32"): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class QLinearMul(OnnxOpConverter): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Operator converter for QLinearMul from Microsoft onnxruntime contrib opset.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@classmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _impl_v10(cls, inputs, attr, params): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_scalar(x, dtype="float32"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if isinstance(x, _expr.Var) and x.name_hint in params: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return _op.const(params[x.name_hint].numpy(), dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
rank = len(infer_shape(x)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert rank <= 1, "QLinearMul scale and zero_point input must be scalars" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if rank == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x = _op.squeeze(x, [0]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return _op.cast(x, dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
a = inputs[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
a_scale = get_scalar(inputs[1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
a_zero_point = get_scalar(inputs[2], "int32") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
b = inputs[3] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
b_scale = get_scalar(inputs[4]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
b_zero_point = get_scalar(inputs[5], "int32") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
y_scale = fold_constant(get_scalar(inputs[6])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
y_zero_point = get_scalar(inputs[7], "int32") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dtype = infer_type(a).checked_type.dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
## Onnxruntime doesn't actually do this op in integer, they dequantize to fp32 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
## and then requantize afer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qlmul.cpp | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+3307
to
+3309
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Even if onnxruntime is performing fp32 operations, is there any reason to do the same here? Wouldn't it be better (at least some what for the performance) to requantize both inputs 'a' and 'b' as per output scale and zero_point and then perform integer matmul? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have a pass ( tvm/python/tvm/relay/transform/transform.py Lines 1177 to 1204 in a31ebf7
BTW - this is elementwise multiplication, not matmul There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank makes sense. Thanks! |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
a = _qnn.op.dequantize(inputs[0], a_scale, a_zero_point) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
b = _qnn.op.dequantize(inputs[3], b_scale, b_zero_point) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
out = _op.multiply(a, b) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class ConvInteger(OnnxOpConverter): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Operator converter for ConvInteger.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -3605,6 +3639,7 @@ def _get_convert_map(opset): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"ReverseSequence": ReverseSequence.get_converter(opset), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"QLinearConv": QLinearConv.get_converter(opset), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"QLinearAdd": QLinearAdd.get_converter(opset), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"QLinearMul": QLinearMul.get_converter(opset), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"ConvInteger": ConvInteger.get_converter(opset), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Random number generation. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"RandomUniform": RandomUniform.get_converter(opset), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😱
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems to be the ORT implementation of QLinearMul, not QLinearMatMul?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Translated everything to QLinearMul instead 🤔 , let me know if it looks right to you now