From 85d1dee0779e5463ed09cb1ec9812c237f1a84d3 Mon Sep 17 00:00:00 2001
From: tpoisonooo
Date: Thu, 2 Jun 2022 19:39:15 +0800
Subject: [PATCH] improvement(ViT): use Crop to substitute Gather (#477)

* improvement(ViT): use Crop to substitute Gather

* fix(CI): code format

* fix(pytorch/ops/linear.py): bias may be None

* fix(test/test_pytorch_ops.py): op_type error

* fix(test): pytest error

* fix(test): torch version 1.8
---
 csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp | 38 ++++++++++++++++
 csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.h   |  5 +++
 csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp | 21 ++++++++-
 .../ncnn/onnx2ncnn/shape_inference.cpp        |  2 +
 .../ncnn/onnx2ncnn/shape_inference.h          |  1 -
 mmdeploy/pytorch/functions/linear.py          |  2 +-
 mmdeploy/pytorch/ops/__init__.py              |  3 +-
 mmdeploy/pytorch/ops/linear.py                | 44 +++++++++++++++++++
 tests/test_pytorch/test_pytorch_ops.py        | 35 +++++++++++++++
 9 files changed, 147 insertions(+), 4 deletions(-)
 create mode 100644 mmdeploy/pytorch/ops/linear.py

diff --git a/csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp b/csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp
index 58727daa71..069c40f193 100644
--- a/csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp
+++ b/csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp
@@ -1,6 +1,44 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 #include "fuse_pass.h"
 
+void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
+                         std::map<std::string, onnx::TensorProto>& weights,
+                         std::map<std::string, int>& node_reference,
+                         std::set<std::string>& blob_names, int& reduced_node_count) {
+  const int node_count = mutable_graph->node_size();
+  for (int i = 0; i < node_count; ++i) {
+    onnx::NodeProto* gather = mutable_graph->mutable_node(i);
+    if (gather->op_type() != "Gather") {
+      continue;
+    }
+    auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]);
+    if (indices.size() != 1) {
+      continue;
+    }
+
+    {
+      // reconstruct node connections
+      node_reference[gather->input(1)] -= 1;
+      std::string origin_inp = gather->input(0);
+      gather->clear_input();
+      gather->add_input(origin_inp);
+    }
+
+    {
+      // update axis, starts and ends
+      int axis = get_node_attr_i(*gather, "axis", 1) - 1;
+
+      gather->set_op_type("Crop");
+      gather->clear_attribute();
+
+      int indice = indices[0];
+      set_node_attr_ai(*gather, "starts", std::vector<int>{indice});
+      set_node_attr_ai(*gather, "ends", std::vector<int>{indice + 1});
+      set_node_attr_ai(*gather, "axis", std::vector<int>{axis});
+    }
+  }
+}
+
 void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
                          std::map<std::string, onnx::TensorProto>& weights,
                          std::map<std::string, int>& node_reference,
diff --git a/csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.h b/csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.h
index 091ae15909..bdd2edd233 100644
--- a/csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.h
+++ b/csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.h
@@ -4,6 +4,11 @@
 #include "shape_inference.h"
 #include "utils.h"
 
+void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
+                         std::map<std::string, onnx::TensorProto>& weights,
+                         std::map<std::string, int>& node_reference,
+                         std::set<std::string>& blob_names, int& reduced_node_count);
+
 void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
                          std::map<std::string, onnx::TensorProto>& weights,
                          std::map<std::string, int>& node_reference,
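
For context on the pass above: it matches any Gather whose indices input is a one-element constant and rewrites the node in place into a Crop slicing [i, i + 1); the axis is shifted down by one because ncnn blobs carry no batch dimension. A minimal before/after sketch of the node-level effect, written with the onnx Python helpers (blob and node names here are illustrative, not taken from the patch):

    from onnx import helper

    # Before: pick index 0 on axis 1, e.g. the class token of a ViT
    # feature map of shape (N, tokens, channels).
    gather = helper.make_node(
        'Gather', inputs=['feat', 'indices'], outputs=['cls'], axis=1)

    # After fuse_rewrite_gather: same result expressed as a Crop over
    # [0, 1) on axis 0 (axis = 1 - 1, since ncnn drops the batch dim).
    crop = helper.make_node(
        'Crop', inputs=['feat'], outputs=['cls'], starts=[0], ends=[1], axis=[0])
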
diff --git a/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp b/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
index 61e6e69f58..925c7a1757 100644
--- a/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
+++ b/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
@@ -229,6 +229,7 @@ int main(int argc, char** argv) {
     fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
     fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
+    fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
   }
 
   // reduce common const weight node_reference
@@ -623,6 +624,8 @@ int main(int argc, char** argv) {
       }
     } else if (op == "Cos") {
       fprintf(pp, "%-16s", "UnaryOp");
+    } else if (op == "Crop") {
+      fprintf(pp, "%-16s", "Crop");
     } else if (op == "DepthToSpace") {
       fprintf(pp, "%-16s", "PixelShuffle");
     } else if (op == "DetectionOutput") {
@@ -1196,6 +1199,22 @@ int main(int argc, char** argv) {
     } else if (op == "Cos") {
       int op_type = 10;
       fprintf(pp, " 0=%d", op_type);
+    } else if (op == "Crop") {
+      auto starts = get_node_attr_ai(node, "starts");
+      fprintf(pp, " -23309=%zu", starts.size());
+      for (size_t j = 0; j < starts.size(); ++j) {
+        fprintf(pp, ",%i", starts[j]);
+      }
+      auto ends = get_node_attr_ai(node, "ends");
+      fprintf(pp, " -23310=%zu", ends.size());
+      for (size_t j = 0; j < ends.size(); ++j) {
+        fprintf(pp, ",%i", ends[j]);
+      }
+      auto axis = get_node_attr_ai(node, "axis");
+      fprintf(pp, " -23311=%zu", axis.size());
+      for (size_t j = 0; j < axis.size(); ++j) {
+        fprintf(pp, ",%i", axis[j]);
+      }
     } else if (op == "DepthToSpace") {
       // pixelshuffle
       int scale_factor = get_node_attr_i(node, "blocksize", 1);
@@ -1287,7 +1306,7 @@ int main(int argc, char** argv) {
     }
     fprintf(pp, " 0=%d", axis);
   } else if (op == "Gelu") {
-    fprintf(pp, " 0=0");
+    fprintf(pp, " 0=1");
   } else if (op == "Gemm") {
     float alpha = get_node_attr_f(node, "alpha", 1.f);
     float beta = get_node_attr_f(node, "beta", 1.f);
diff --git a/csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp b/csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp
index f2402d96e1..dd1fe2c4f6 100644
--- a/csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp
+++ b/csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp
@@ -2,6 +2,8 @@
 
 #include "shape_inference.h"
 
+#include <algorithm>
+
 /**
  * @brief query output shape of target node
  *
diff --git a/csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.h b/csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.h
index 38b3365a80..fa62ffe9de 100644
--- a/csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.h
+++ b/csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.h
@@ -1,7 +1,6 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 #pragma once
 
-#include <algorithm>
 
 #include "utils.h"
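
A note on the magic numbers in the Crop branch above: ncnn's param format encodes an array-valued parameter n under the id -23300 - n, and the first value after "=" is the array length. So starts, ends and axis land in array parameters 9, 10 and 11 of the Crop layer. For a Gather of index 0 on ONNX axis 1, the emitted param line would look roughly like this (layer and blob names are illustrative):

    Crop             gather_0                 1 1 feat cls -23309=1,0 -23310=1,1 -23311=1,0
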
diff --git a/mmdeploy/pytorch/functions/linear.py b/mmdeploy/pytorch/functions/linear.py
index a78708f840..d919cc803c 100644
--- a/mmdeploy/pytorch/functions/linear.py
+++ b/mmdeploy/pytorch/functions/linear.py
@@ -25,7 +25,7 @@ def linear__ncnn(
 
     dim = input.dim()
 
-    if dim == 2:
+    if dim == 2 or (dim == 3 and input.shape[0] == 1):
         return origin_func(input, weight, bias)
     else:
         out = origin_func(input, weight)
diff --git a/mmdeploy/pytorch/ops/__init__.py b/mmdeploy/pytorch/ops/__init__.py
index 0836174713..77da009265 100644
--- a/mmdeploy/pytorch/ops/__init__.py
+++ b/mmdeploy/pytorch/ops/__init__.py
@@ -8,6 +8,7 @@
 from .hardsigmoid import hardsigmoid__default
 from .instance_norm import instance_norm__tensorrt
 from .layer_norm import layer_norm__ncnn
+from .linear import linear__ncnn
 from .lstm import generic_rnn__ncnn
 from .squeeze import squeeze__default
 
@@ -16,5 +17,5 @@
     'adaptive_avg_pool3d__default', 'grid_sampler__default',
     'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
     'squeeze__default', 'adaptive_avg_pool2d__ncnn', 'gelu__ncnn',
-    'layer_norm__ncnn'
+    'layer_norm__ncnn', 'linear__ncnn'
 ]
diff --git a/mmdeploy/pytorch/ops/linear.py b/mmdeploy/pytorch/ops/linear.py
new file mode 100644
index 0000000000..8cb997b400
--- /dev/null
+++ b/mmdeploy/pytorch/ops/linear.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from:
+# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
+from torch.onnx.symbolic_helper import parse_args
+
+from mmdeploy.core import SYMBOLIC_REWRITER
+from mmdeploy.utils import Backend
+
+
+@parse_args('v', 'v')
+def linear_no_bias(g, input, weight):
+    """Symbolic function for `linear` without bias.
+
+    PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
+    """
+    return g.op(
+        'Gemm', input, weight, alpha_f=1.0, beta_f=1.0, transA_i=0, transB_i=1)
+
+
+@parse_args('v', 'v', 'v')
+def linear_normal(g, input, weight, bias):
+    """Symbolic function for `linear`.
+
+    PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
+    """
+    return g.op(
+        'Gemm',
+        input,
+        weight,
+        bias,
+        alpha_f=1.0,
+        beta_f=1.0,
+        transA_i=0,
+        transB_i=1)
+
+
+@SYMBOLIC_REWRITER.register_symbolic(
+    'linear', is_pytorch=True, backend=Backend.NCNN.value)
+def linear__ncnn(ctx, g, input, weight, bias):
+    """Support export of `linear`; this rewrite enables exporting it as Gemm."""
+    if bias is None:
+        return linear_no_bias(g, input, weight)
+    else:
+        return linear_normal(g, input, weight, bias)
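
A hedged sketch of how the new symbolic is picked up during export, assuming mmdeploy's RewriterContext as the entry point (the deploy config content is elided here and the file name is arbitrary):

    import torch
    from mmdeploy.core import RewriterContext

    model = torch.nn.Linear(3, 2).eval()
    x = torch.rand(1, 4, 3)  # 3-D input with batch 1, the ViT case this PR targets

    # With the ncnn backend selected, `linear` should be exported as a single
    # Gemm instead of a MatMul/Add pair that needs a Gather workaround later.
    with RewriterContext(cfg=dict(), backend='ncnn'), torch.no_grad():
        torch.onnx.export(model, x, 'linear.onnx', opset_version=11)
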
+ """ + return g.op( + 'Gemm', + input, + weight, + bias, + alpha_f=1.0, + beta_f=1.0, + transA_i=0, + transB_i=1) + + +@SYMBOLIC_REWRITER.register_symbolic( + 'linear', is_pytorch=True, backend=Backend.NCNN.value) +def linear__ncnn(ctx, g, input, weight, bias): + """Support export linear This rewrite enable export Gemm.""" + if bias is None: + return linear_no_bias(g, input, weight) + else: + return linear_normal(g, input, weight, bias) diff --git a/tests/test_pytorch/test_pytorch_ops.py b/tests/test_pytorch/test_pytorch_ops.py index 0b818eaabd..841e3ea7e2 100644 --- a/tests/test_pytorch/test_pytorch_ops.py +++ b/tests/test_pytorch/test_pytorch_ops.py @@ -127,6 +127,41 @@ def test_instance_norm(): assert nodes[4].domain == 'mmdeploy' +@pytest.mark.usefixtures('prepare_symbolics_ncnn') +class TestLinear: + + def check(self, nodes): + print(nodes) + + from packaging.version import parse as version_parse + version = version_parse(torch.__version__) + target = 'Gemm' + if version.major <= 1 and version.minor <= 8: + target = 'MatMul' + exist = False + for node in nodes: + if node.op_type == target: + exist = True + break + + assert exist is True + + def test_normal(self): + x = torch.rand(1, 2, 3) + w = torch.rand(2, 3) + bias = torch.rand(2) + model = OpModel(torch.nn.functional.linear, w, bias).eval() + nodes = get_model_onnx_nodes(model, x) + self.check(nodes) + + def test_no_bias(self): + x = torch.rand(1, 2, 3) + w = torch.rand(2, 3) + model = OpModel(torch.nn.functional.linear, w).eval() + nodes = get_model_onnx_nodes(model, x) + self.check(nodes) + + @pytest.mark.usefixtures('prepare_symbolics') class TestSqueeze: