diff --git a/onnxoptimizer/passes/fuse_add_bias_into_conv.h b/onnxoptimizer/passes/fuse_add_bias_into_conv.h index df064973f..6e7cfc876 100644 --- a/onnxoptimizer/passes/fuse_add_bias_into_conv.h +++ b/onnxoptimizer/passes/fuse_add_bias_into_conv.h @@ -38,11 +38,22 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass { } static Node *makeSqueezeOrUnsqueeze(Graph &graph, std::vector<int64_t> &axes, Value *input, Node *target_node, - BuiltinSymbol k) { + BuiltinSymbol k, bool is_input_qdq) { assert(k == kSqueeze || k == kUnsqueeze); Node *squeeze = graph.create(k, 1); - int opset_version = getOpsetVersion(graph); + Node *dequant_node = nullptr; + Node *quant_node = nullptr; + // insert squeeze op before qdq + if (is_input_qdq) { + dequant_node = input->node(); + quant_node = dequant_node->input(0)->node(); + target_node = quant_node; + input = target_node->input(0); + dequant_node->output()->clearMetadata(); + quant_node->output()->clearMetadata(); + } squeeze->addInput(input); + int opset_version = getOpsetVersion(graph); int version_threshold = 13; if (opset_version < version_threshold && opset_version != 0) { squeeze->is_(kaxes, std::move(axes)); @@ -54,7 +65,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass { Value *tv = graph.addInitializerAndInput(t); squeeze->addInput(tv); } + if (is_input_qdq) { + quant_node->replaceInput(0, squeeze->output()); + } squeeze->insertBefore(target_node); + if (is_input_qdq) { + return dequant_node; + } return squeeze; } bool runTransform(Node *n, Graph &graph, @@ -115,13 +132,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass { if (bias_shape.size() > 1) { std::vector<int64_t> axes(bias_shape.size() - 1); std::iota(axes.begin(), axes.end(), 0); - Node *squeeze = makeSqueezeOrUnsqueeze(graph, axes, conv_3rd_input, - orig_conv->node(), kSqueeze); + Node *squeeze = makeSqueezeOrUnsqueeze( + graph, axes, conv_3rd_input, orig_conv->node(), kSqueeze, false); conv_3rd_input = squeeze->output(); } else if 
(bias_shape.size() == 0) { std::vector<int64_t> axes = {0}; - Node *unsqueeze = makeSqueezeOrUnsqueeze(graph, axes, conv_3rd_input, - orig_conv->node(), kUnsqueeze); + Node *unsqueeze = makeSqueezeOrUnsqueeze( + graph, axes, conv_3rd_input, orig_conv->node(), kUnsqueeze, false); conv_3rd_input = unsqueeze->output(); } if (M > 1) { @@ -149,17 +166,25 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass { bias_shape[1 + bias_shape.size() - static_cast<size_t>(rank)] .dim == M) { ONNX_ASSERT(bias_shape.size() > 1); + const bool is_input_qdq = + orig_bias->node()->kind() == Symbol("DequantizeLinear") && + orig_bias->node()->input(0)->node()->kind() == + Symbol("QuantizeLinear"); if (orig_bias->node()->kind() != kParam && orig_conv->node()->isBefore(orig_bias->node())) { + if (is_input_qdq) { + orig_bias->node()->input(0)->node()->moveBefore(orig_conv->node()); + } orig_bias->node()->moveBefore(orig_conv->node()); } std::vector<int64_t> axes(bias_shape.size()); std::iota(axes.begin(), axes.end(), static_cast<int64_t>(0)); axes.erase(axes.begin() + (1 + bias_shape.size() - static_cast<size_t>(rank))); - Node *squeeze = makeSqueezeOrUnsqueeze(graph, axes, orig_bias, - orig_conv->node(), kSqueeze); - orig_conv->node()->addInput(squeeze->output()); + + Node *new_bias = makeSqueezeOrUnsqueeze( + graph, axes, orig_bias, orig_conv->node(), kSqueeze, is_input_qdq); + orig_conv->node()->addInput(new_bias->output()); } else { return false; } diff --git a/onnxoptimizer/test/optimizer_test.py b/onnxoptimizer/test/optimizer_test.py index 517eb7920..7273a80dc 100644 --- a/onnxoptimizer/test/optimizer_test.py +++ b/onnxoptimizer/test/optimizer_test.py @@ -1150,6 +1150,41 @@ def test_fuse_add_bias_into_conv_with_non_constant_bias(self): assert optimized_model.graph.node[2].op_type == 'Conv' assert optimized_model.graph.output[0].name == 'C' + # type: () -> None + def test_fuse_add_bias_into_conv_with_quanted_bias(self): + nodes = [helper.make_node("Conv", ["X", "Y"], ["Z"]), + helper.make_node("QuantizeLinear", ["A", 
"scale", "zero_point"], ["B"], axis=0), + helper.make_node("DequantizeLinear", ["B", "scale", "zero_point"], ["C"], axis=0), + helper.make_node("Add", ["Z", "C"], ["D"])] + graph = helper.make_graph( + nodes, + "test", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 5, 3, 3)), + helper.make_tensor_value_info( + "Y", TensorProto.FLOAT, (16, 5, 3, 3)), + helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 1, 1))], + [helper.make_tensor_value_info( + "D", TensorProto.FLOAT, (1, 16, 1, 1))], + [helper.make_tensor("scale", TensorProto.FLOAT, + dims=(16,), + vals=np.random.rand(16).astype(np.float32).tobytes(), + raw=True), + helper.make_tensor("zero_point", TensorProto.INT8, + dims=(16,), + vals=np.zeros([16]).astype(np.int8).tobytes(), + raw=True)], + value_info=[helper.make_tensor_value_info( + "C", TensorProto.FLOAT, (16, 1, 1))] + ) + optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"], opset_imports=[helper.make_opsetid("", 13)]) + + assert len(list(optimized_model.graph.node)) == 4 + assert optimized_model.graph.node[0].op_type == 'Squeeze' + assert optimized_model.graph.node[1].op_type == 'QuantizeLinear' + assert optimized_model.graph.node[2].op_type == 'DequantizeLinear' + assert optimized_model.graph.node[3].op_type == 'Conv' + assert optimized_model.graph.output[0].name == 'D' + def test_fuse_matmul_add_bias_into_gemm(self): # type: () -> None matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"]) add = helper.make_node("Add", ["Z", "B"], ["A"]) diff --git a/third_party/onnx b/third_party/onnx index ed3845743..9eb78ecf9 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit ed38457437bcc8aa2c6699ab202cf83db0234891 +Subproject commit 9eb78ecf9dda63c8f4a994247959a1f834d5e8f8