Skip to content

Commit

Permalink
Merge branch 'improve-vit' of https://github.com/tpoisonooo/mmdeploy
Browse files Browse the repository at this point in the history
…into improve-vit
  • Loading branch information
tpoisonooo committed May 30, 2022
2 parents 9634189 + 1a14676 commit cf3a54b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 15 deletions.
2 changes: 1 addition & 1 deletion csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,7 @@ int main(int argc, char** argv) {
}
fprintf(pp, " 0=%d", axis);
} else if (op == "Gelu") {
fprintf(pp, " 0=0");
fprintf(pp, " 0=1");
} else if (op == "Gemm") {
float alpha = get_node_attr_f(node, "alpha", 1.f);
float beta = get_node_attr_f(node, "beta", 1.f);
Expand Down
19 changes: 5 additions & 14 deletions mmdeploy/pytorch/ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,8 @@ def linear_no_bias(g, input, weight):
PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
"""
g.op(
'mmdeploy::Gemm',
input,
weight,
alpha_f=1.0,
beta_f=1.0,
transA_i=0,
transB_i=1)
return g.op(
'Gemm', input, weight, alpha_f=1.0, beta_f=1.0, transA_i=0, transB_i=1)


@parse_args('v', 'v', 'v', 'f', 'f', 'i', 'i')
Expand All @@ -29,8 +23,8 @@ def linear_normal(g, input, weight, bias):
PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
"""
g.op(
'mmdeploy::Gemm',
return g.op(
'Gemm',
input,
weight,
bias,
Expand All @@ -41,10 +35,7 @@ def linear_normal(g, input, weight, bias):


@SYMBOLIC_REWRITER.register_symbolic(
'torch.nn.functional.linear',
is_pytorch=True,
# arg_descriptors=['v', 'v', 'v', 'f', 'f', 'i', 'i'],
backend=Backend.NCNN.value)
'linear', is_pytorch=True, backend=Backend.NCNN.value)
def linear__ncnn(ctx, g, input, weight, bias):
"""Support export linear This rewrite enable export Gemm."""
if bias is None:
Expand Down

0 comments on commit cf3a54b

Please sign in to comment.