Skip to content

Commit

Permalink
Convert OPT MatMuls with quantized inputs to MatMulInteger (#1585)
Browse files Browse the repository at this point in the history
* Convert OPT MatMuls with quantized inputs to MatMulInteger

* Quality check

* Return conversion count

* Add Mul node to rescale the quantized output back to FP32

* Quality fixes

---------

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Co-authored-by: Alexandre Marques <alexandre@neuralmagic.com>
  • Loading branch information
3 people authored Jun 8, 2023
1 parent ecedec8 commit 1575944
Showing 1 changed file with 197 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -568,10 +568,195 @@ def _convert_quantizable_gemm(
remove_node_and_params_from_graph(model, gemm_node)


def _convert_quantizable_matmul(model: ModelProto):
def _convert_quantizable_matmuls_with_nonquantized_outputs(model: ModelProto):
"""
A pass for converting a MatMul into a quantized representation
This MatMul is the result of quantizing native torch.matmul using QATMatMul
A pass for converting a MatMul with quantized inputs into
a MatMulInteger
| Starting with:
| INPUT_0 INPUT_1
| | |
| QuantizeLinear QuantizeLinear
| | |
| DequantizeLinear DequantizeLinear
| | |
| MatMul
| |
| Add (optional)
| |
| OUTPUT
| We end up converting to:
| INPUT_0 INPUT_1
| | |
| QuantizeLinear QuantizeLinear
| | |
| | |
| MatMulInteger
| |
| Add (Optional)
| |
| Cast (Int32 --> FP32)
| |
| Mul
| |
| OUTPUT
"""

conversion_count = 0
matmul_nodes = [n for n in model.graph.node if n.op_type in ["MatMul"]]
graph = ONNXGraph(model)
for matmul_node in matmul_nodes:
#############
# Matching
#############

input_dequantize_nodes = [
graph.get_node_single_parent(matmul_node, i) for i in range(2)
]

# Make sure these input nodes are DequantizeLinear
if numpy.any(
[
(node is None or node.op_type != "DequantizeLinear")
for node in input_dequantize_nodes
]
):
continue

# Make sure their parents are QuantizeLinear
parents = [
graph.get_node_single_parent(node, 0) for node in input_dequantize_nodes
]
if numpy.any(
[
(parent is None or parent.op_type != "QuantizeLinear")
for parent in parents
]
):
continue

_LOGGER.debug(f"Matched quantizable MatMul: {matmul_node.name}")

# Create MatMulInteger node
node_0, node_1 = input_dequantize_nodes

input_nodes = [
node_0.input[0], # a
node_1.input[0], # b
node_0.input[2], # a_zero_point
node_1.input[2], # b_zero_point
]

matmul_int_op_node = onnx.helper.make_node(
"MatMulInteger",
input_nodes,
[f"{matmul_node.name}_quant_out"],
f"{matmul_node.name}_quant",
)
model.graph.node.append(matmul_int_op_node)

node_0_parameters = get_quantization_params(model, node_0)
node_1_parameters = get_quantization_params(model, node_1)

output_scale = node_0_parameters.scale * node_1_parameters.scale

has_bias = False

# Check if is followed by Add node (bias)
bias_add_node = graph.get_node_single_child(matmul_node)
if bias_add_node is not None and bias_add_node.op_type == "Add":
bias_initializer = get_init_by_name(
model, bias_add_node.input[1]
) or get_init_by_name(model, bias_add_node.input[0])
if bias_initializer is not None:
# check if bias is finite
bias_initializer = numpy_helper.to_array(bias_initializer)
if numpy.all(numpy.isfinite(bias_initializer)):
# Create initializer for quantized bias
quantized_bias_initializer_name = f"{bias_initializer.name}_quant"
has_bias = True

bias_zero_point = 0
quantized_bias = _quantize_array(
bias_initializer,
output_scale,
bias_zero_point,
dtype=numpy.int32,
)
quantized_bias_initializer = numpy_helper.from_array(
quantized_bias,
name=quantized_bias_initializer_name,
)
model.graph.initializer.append(quantized_bias_initializer)

# Create new Add node for quantized bias
quantized_add_node_name = f"{bias_add_node.name}_quant"
quantized_add_node = onnx.helper.make_node(
"Add",
[matmul_int_op_node.output[0], quantized_bias_initializer_name],
[f"{quantized_add_node_name}_output"],
quantized_add_node_name,
)
model.graph.node.append(quantized_add_node)

# Casting MatMulInteger INT32 output to FP32

cast_node_name = f"{matmul_node.name}_cast"
cast_node_input = (
quantized_add_node.output if has_bias else matmul_int_op_node.output
)
cast_node = onnx.helper.make_node(
"Cast",
cast_node_input,
[f"{cast_node_name}_output"],
cast_node_name,
to=getattr(onnx.TensorProto, "FLOAT"), # get Float32 enum id
)
model.graph.node.append(cast_node)

output_scale_initializer_name = f"{matmul_node.name}.output_scale"
model.graph.initializer.append(
numpy_helper.from_array(
numpy.asarray(output_scale),
name=output_scale_initializer_name,
)
)

mul_node_output = bias_add_node.output if has_bias else matmul_node.output
mul_node = onnx.helper.make_node(
"Mul",
[cast_node.output[0], output_scale_initializer_name],
mul_node_output,
f"{matmul_node.name}_scale",
)
model.graph.node.append(mul_node)

for node in input_dequantize_nodes:
delete_quant_node(model, node)

# delete original MatMul node
remove_node_and_params_from_graph(model, matmul_node)

# delete original Add node
if has_bias:
remove_node_and_params_from_graph(model, bias_add_node)

conversion_count += 1

if matmul_nodes:
_LOGGER.info(
f"Converted {conversion_count} quantizable MatMul "
"(A8A8 inputs, FP output) ops to MatMulInteger"
)
graph = ONNXGraph(model)
graph.delete_unused_initializers()


def _convert_quantizable_matmul_with_quantized_outputs(model: ModelProto):
"""
A pass for converting a MatMul with quantized inputs and outputs into
a QLinearMatMul. This MatMul is the result of quantizing native
torch.matmul using QATMatMul
| Starting with:
| INPUT_0 INPUT_1
Expand Down Expand Up @@ -732,9 +917,17 @@ def _convert_quantizable_matmul(model: ModelProto):

if matmul_nodes:
_LOGGER.info(
f"Converted {conversion_count} quantizable MatMul ops " "to QLinearMatMul"
f"Converted {conversion_count} quantizable MatMul with quantized outputs "
"to QLinearMatMul"
)

return conversion_count


def _convert_quantizable_matmul(model: ModelProto):
_convert_quantizable_matmul_with_quantized_outputs(model)
_convert_quantizable_matmuls_with_nonquantized_outputs(model)


def _add_quantized_conv_matmul_add_ops(
model: ModelProto,
Expand Down

0 comments on commit 1575944

Please sign in to comment.