From 070a6d40f27525427dd1c12153019a21f8fe9ac4 Mon Sep 17 00:00:00 2001 From: WXB <64680548+XiaBing992@users.noreply.github.com> Date: Mon, 14 Aug 2023 15:46:31 +0800 Subject: [PATCH] support torch.t to ncnn (#4940) --- tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/pass_ncnn/torch_t.cpp | 54 ++++++++++++++++++++++++ tools/pnnx/tests/ncnn/CMakeLists.txt | 1 + tools/pnnx/tests/ncnn/test_torch_t.py | 61 +++++++++++++++++++++++++++ 4 files changed, 117 insertions(+) create mode 100644 tools/pnnx/src/pass_ncnn/torch_t.cpp create mode 100644 tools/pnnx/tests/ncnn/test_torch_t.py diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 4c5a1e15cd63..bfbc2fa66116 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -537,6 +537,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/torch_prod.cpp pass_ncnn/torch_squeeze.cpp pass_ncnn/torch_sum.cpp + pass_ncnn/torch_t.cpp pass_ncnn/torch_transpose.cpp pass_ncnn/torch_unsqueeze.cpp pass_ncnn/torchvision_DeformConv2d.cpp diff --git a/tools/pnnx/src/pass_ncnn/torch_t.cpp b/tools/pnnx/src/pass_ncnn/torch_t.cpp new file mode 100644 index 000000000000..fa18297d381b --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_t.cpp @@ -0,0 +1,54 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_t : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.t op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Permute"; + } + + const char* name_str() const + { + return "t"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_t, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index 945576bfaf6d..208c6f18b4b2 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -153,6 +153,7 @@ pnnx_ncnn_add_test(torch_norm) pnnx_ncnn_add_test(torch_permute) pnnx_ncnn_add_test(torch_prod) pnnx_ncnn_add_test(torch_sum) +pnnx_ncnn_add_test(torch_t) pnnx_ncnn_add_test(torch_squeeze) pnnx_ncnn_add_test(torch_stack) pnnx_ncnn_add_test(torch_tensor_split) diff --git a/tools/pnnx/tests/ncnn/test_torch_t.py b/tools/pnnx/tests/ncnn/test_torch_t.py new file mode 100644 index 000000000000..2ade16a38752 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_t.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = torch.t(x) + y = torch.t(y) + x = F.relu(x) + y = F.relu(y) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3) + y = torch.rand(5, 9) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_t.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_t.pt inputshape=[3],[5,9]") + + # ncnn inference + import test_torch_t_ncnn + b = test_torch_t_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)