Skip to content

Commit

Permalink
improvement(ViT): use Crop to subtitude Gather (#477)
Browse files Browse the repository at this point in the history
* improvement(ViT): use Crop to subtitude Gather

* fix(CI): code format

* fix(pytorch/ops/linear.py): bias maybe None

* fix(test/test_pytorch_ops.py): op_type error

* fix(test): pytest error

* fix(test): torch version 1.8
  • Loading branch information
tpoisonooo authored Jun 2, 2022
1 parent ee878b5 commit cd336ea
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 4 deletions.
38 changes: 38 additions & 0 deletions csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,44 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "fuse_pass.h"

void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
const int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; ++i) {
onnx::NodeProto* gather = mutable_graph->mutable_node(i);
if (gather->op_type() != "Gather") {
continue;
}
auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]);
if (indices.size() != 1) {
continue;
}

{
// reconstruct node connections
node_reference[gather->input(1)] -= 1;
std::string origin_inp = gather->input(0);
gather->clear_input();
gather->add_input(origin_inp);
}

{
// update axis, starts and ends
int axis = get_node_attr_i(*gather, "axis", 1) - 1;

gather->set_op_type("Crop");
gather->clear_attribute();

int indice = indices[0];
set_node_attr_ai(*gather, "starts", std::vector<int>{indice});
set_node_attr_ai(*gather, "ends", std::vector<int>{indice + 1});
set_node_attr_ai(*gather, "axis", std::vector<int>{axis});
}
}
}

void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
Expand Down
5 changes: 5 additions & 0 deletions csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
#include "shape_inference.h"
#include "utils.h"

void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
Expand Down
21 changes: 20 additions & 1 deletion csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ int main(int argc, char** argv) {
fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names,
reduced_node_count);
fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
}

// reduce common const weight node_reference
Expand Down Expand Up @@ -623,6 +624,8 @@ int main(int argc, char** argv) {
}
} else if (op == "Cos") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Crop") {
fprintf(pp, "%-16s", "Crop");
} else if (op == "DepthToSpace") {
fprintf(pp, "%-16s", "PixelShuffle");
} else if (op == "DetectionOutput") {
Expand Down Expand Up @@ -1196,6 +1199,22 @@ int main(int argc, char** argv) {
} else if (op == "Cos") {
int op_type = 10;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Crop") {
auto starts = get_node_attr_ai(node, "starts");
fprintf(pp, " -23309=%zu", starts.size());
for (size_t j = 0; j < starts.size(); ++j) {
fprintf(pp, ",%i", starts[j]);
}
auto ends = get_node_attr_ai(node, "ends");
fprintf(pp, " -23310=%zu", ends.size());
for (size_t j = 0; j < ends.size(); ++j) {
fprintf(pp, ",%i", ends[j]);
}
auto axis = get_node_attr_ai(node, "axis");
fprintf(pp, " -23311=%zu", axis.size());
for (size_t j = 0; j < axis.size(); ++j) {
fprintf(pp, ",%i", axis[j]);
}
} else if (op == "DepthToSpace") {
// pixelshuffle
int scale_factor = get_node_attr_i(node, "blocksize", 1);
Expand Down Expand Up @@ -1287,7 +1306,7 @@ int main(int argc, char** argv) {
}
fprintf(pp, " 0=%d", axis);
} else if (op == "Gelu") {
fprintf(pp, " 0=0");
fprintf(pp, " 0=1");
} else if (op == "Gemm") {
float alpha = get_node_attr_f(node, "alpha", 1.f);
float beta = get_node_attr_f(node, "beta", 1.f);
Expand Down
2 changes: 2 additions & 0 deletions csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include "shape_inference.h"

#include <algorithm>

/**
* @brief query output shape of target node
*
Expand Down
1 change: 0 additions & 1 deletion csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once
#include <algorithm>

#include "utils.h"

Expand Down
2 changes: 1 addition & 1 deletion mmdeploy/pytorch/functions/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def linear__ncnn(

dim = input.dim()

if dim == 2:
if dim == 2 or dim == 3 and input.shape[0] == 1:
return origin_func(input, weight, bias)
else:
out = origin_func(input, weight)
Expand Down
3 changes: 2 additions & 1 deletion mmdeploy/pytorch/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .hardsigmoid import hardsigmoid__default
from .instance_norm import instance_norm__tensorrt
from .layer_norm import layer_norm__ncnn
from .linear import linear__ncnn
from .lstm import generic_rnn__ncnn
from .squeeze import squeeze__default

Expand All @@ -16,5 +17,5 @@
'adaptive_avg_pool3d__default', 'grid_sampler__default',
'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
'squeeze__default', 'adaptive_avg_pool2d__ncnn', 'gelu__ncnn',
'layer_norm__ncnn'
'layer_norm__ncnn', 'linear__ncnn'
]
44 changes: 44 additions & 0 deletions mmdeploy/pytorch/ops/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from:
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
from torch.onnx.symbolic_helper import parse_args

from mmdeploy.core import SYMBOLIC_REWRITER
from mmdeploy.utils import Backend


@parse_args('v', 'v', 'f', 'f', 'i', 'i')
def linear_no_bias(g, input, weight):
"""Symbolic function for `linear` without bias.
PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
"""
return g.op(
'Gemm', input, weight, alpha_f=1.0, beta_f=1.0, transA_i=0, transB_i=1)


@parse_args('v', 'v', 'v', 'f', 'f', 'i', 'i')
def linear_normal(g, input, weight, bias):
"""Symbolic function for `linear`.
PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
"""
return g.op(
'Gemm',
input,
weight,
bias,
alpha_f=1.0,
beta_f=1.0,
transA_i=0,
transB_i=1)


@SYMBOLIC_REWRITER.register_symbolic(
'linear', is_pytorch=True, backend=Backend.NCNN.value)
def linear__ncnn(ctx, g, input, weight, bias):
"""Support export linear This rewrite enable export Gemm."""
if bias is None:
return linear_no_bias(g, input, weight)
else:
return linear_normal(g, input, weight, bias)
35 changes: 35 additions & 0 deletions tests/test_pytorch/test_pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,41 @@ def test_instance_norm():
assert nodes[4].domain == 'mmdeploy'


@pytest.mark.usefixtures('prepare_symbolics_ncnn')
class TestLinear:

def check(self, nodes):
print(nodes)

from packaging.version import parse as version_parse
version = version_parse(torch.__version__)
target = 'Gemm'
if version.major <= 1 and version.minor <= 8:
target = 'MatMul'
exist = False
for node in nodes:
if node.op_type == target:
exist = True
break

assert exist is True

def test_normal(self):
x = torch.rand(1, 2, 3)
w = torch.rand(2, 3)
bias = torch.rand(2)
model = OpModel(torch.nn.functional.linear, w, bias).eval()
nodes = get_model_onnx_nodes(model, x)
self.check(nodes)

def test_no_bias(self):
x = torch.rand(1, 2, 3)
w = torch.rand(2, 3)
model = OpModel(torch.nn.functional.linear, w).eval()
nodes = get_model_onnx_nodes(model, x)
self.check(nodes)


@pytest.mark.usefixtures('prepare_symbolics')
class TestSqueeze:

Expand Down

0 comments on commit cd336ea

Please sign in to comment.