From 92c59b95757d6605cefa698891b5b3805b14e844 Mon Sep 17 00:00:00 2001 From: Beq Jal <447727944@qq.com> Date: Fri, 11 Aug 2023 23:47:44 +0800 Subject: [PATCH 1/9] modify cmakelist --- tools/pnnx/src/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 4c5a1e15cd6..17e58120292 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -216,6 +216,7 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_cross.cpp pass_level2/torch_cumsum.cpp pass_level2/torch_dequantize.cpp + pass_level2/torch_diag.cpp pass_level2/torch_einsum.cpp pass_level2/torch_empty.cpp pass_level2/torch_empty_like.cpp From 8869ec3939416296d5a925a6da7c0810ef23a373 Mon Sep 17 00:00:00 2001 From: Beq Jal <447727944@qq.com> Date: Sat, 12 Aug 2023 16:40:28 +0800 Subject: [PATCH 2/9] add convert diag to pnnx --- tools/pnnx/tests/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 26ea2005d28..1f4470982dd 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -197,6 +197,7 @@ pnnx_add_test(torch_cross) pnnx_add_test(torch_cumsum) pnnx_add_test(torch_einsum) pnnx_add_test(torch_eq) +pnnx_add_test(torch_diag) pnnx_add_test(torch_flatten) pnnx_add_test(torch_full) pnnx_add_test(torch_full_like) From 039e7382400211489cc5ce92b67defe762e5bf91 Mon Sep 17 00:00:00 2001 From: Beq Jal <447727944@qq.com> Date: Sat, 12 Aug 2023 16:42:58 +0800 Subject: [PATCH 3/9] add convert diag to pnnx --- tools/pnnx/src/pass_level2/torch_diag.cpp | 41 ++++++++++++++++ tools/pnnx/tests/test_torch_diag.py | 60 +++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 tools/pnnx/src/pass_level2/torch_diag.cpp create mode 100644 tools/pnnx/tests/test_torch_diag.py diff --git a/tools/pnnx/src/pass_level2/torch_diag.cpp b/tools/pnnx/src/pass_level2/torch_diag.cpp new file mode 100644 index 00000000000..3fb784b3cd9 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_diag.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_diag : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 diagonal +aten::diag op_0 2 1 input diagonal out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.diag"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_diag, 20) + +} // namespace pnnx diff --git a/tools/pnnx/tests/test_torch_diag.py b/tools/pnnx/tests/test_torch_diag.py new file mode 100644 index 00000000000..45eba9b6f5d --- /dev/null +++ b/tools/pnnx/tests/test_torch_diag.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.diag(x, -1) + y = torch.diag(y) + z = torch.diag(z, 3) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(7) + y = torch.rand(5, 5) + z = torch.rand(4, 8) + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_diag.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_diag.pt inputshape=[7],[5,5],[4,8]") + + # pnnx inference + import test_torch_diag_pnnx + b = test_torch_diag_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) From e5eb7e4832d454dfdd68cf9a0e23ab6d6a776823 Mon Sep 17 00:00:00 2001 From: Beq Jal <447727944@qq.com> Date: Sun, 13 Aug 2023 19:37:33 +0800 Subject: [PATCH 4/9] add diag layer to ncnn --- src/CMakeLists.txt | 1 + src/layer/diag.cpp | 82 +++++++++++++++++++++++++ src/layer/diag.h | 37 +++++++++++ tests/CMakeLists.txt | 1 + tests/test_diag.cpp | 57 +++++++++++++++++ tools/pnnx/src/pass_ncnn/torch_diag.cpp | 63 +++++++++++++++++++ 6 files changed, 241 insertions(+) create mode 100644 src/layer/diag.cpp create mode 100644 src/layer/diag.h create mode 100644 tests/test_diag.cpp create mode 100644 tools/pnnx/src/pass_ncnn/torch_diag.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 87dc19616ad..be1742ac6e6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -71,6 +71,7 @@ ncnn_add_layer(Concat) ncnn_add_layer(Convolution) ncnn_add_layer(Crop) ncnn_add_layer(Deconvolution) +ncnn_add_layer(Diag) ncnn_add_layer(Dropout) ncnn_add_layer(Eltwise) ncnn_add_layer(ELU) diff --git a/src/layer/diag.cpp b/src/layer/diag.cpp new file mode 100644 index 00000000000..936b32b2366 --- /dev/null +++ b/src/layer/diag.cpp @@ -0,0 +1,82 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "diag.h" + +namespace ncnn { + +Diag::Diag() +{ + one_blob_only = true; + support_inplace = false; +} + +int Diag::load_param(const ParamDict& pd) +{ + diagonal = pd.get(0, 0); + + return 0; +} + +int Diag::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int dims = bottom_blob.dims; + size_t elemsize = bottom_blob.elemsize; + + if (dims == 1) + { + int w = bottom_blob.w; + int top_w = w + std::abs(diagonal); + int stride = top_w + 1; + + top_blob.create(top_w, top_w, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + top_blob.fill(0.0f); + + int bias_r = -std::min(diagonal, 0); + int bias_c = std::max(diagonal, 0); + + for (int i = 0; i < w; i++) + { + top_blob.row(i + bias_r)[i + bias_c] = bottom_blob[i]; + } + } + if (dims == 2) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + float tmp = (w - h) / 2.0; + + int len = std::min(w, h) - (int)std::max(std::abs(diagonal - tmp) - std::abs(tmp), 0.0f); + len = std::max(len, 0); + + top_blob.create(len, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + int bias_r = -std::min(diagonal, 0); + int bias_c = std::max(diagonal, 0); + + for (int i = 0; i < len; i++) + { + top_blob[i] = bottom_blob.row(i + bias_r)[i + bias_c]; + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/diag.h b/src/layer/diag.h new file mode 100644 index 00000000000..5eaf5f38d5e --- /dev/null +++ b/src/layer/diag.h @@ -0,0 +1,37 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_DIAG_H +#define LAYER_DIAG_H + +#include "layer.h" + +namespace ncnn { + +class Diag : public Layer +{ +public: + Diag(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + +public: + int diagonal; +}; + +} // namespace ncnn + +#endif // LAYER_DIAG_H \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 78ebd4de5f0..319be191fb5 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -92,6 +92,7 @@ ncnn_add_layer_test(DeconvolutionDepthWise3D) ncnn_add_layer_test(DeepCopy) ncnn_add_layer_test(DeformableConv2D) ncnn_add_layer_test(Dequantize) +ncnn_add_layer_test(Diag) ncnn_add_layer_test(Dropout) ncnn_add_layer_test(Einsum) ncnn_add_layer_test(Eltwise) diff --git a/tests/test_diag.cpp b/tests/test_diag.cpp new file mode 100644 index 00000000000..0846bd9052a --- /dev/null +++ b/tests/test_diag.cpp @@ -0,0 +1,57 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "layer/diag.h" +#include "testutil.h" + +static int test_diag(const ncnn::Mat& a, int diagonal) +{ + ncnn::ParamDict pd; + pd.set(0, diagonal); + + std::vector weights(0); + + int ret = test_layer("Diag", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_diag failed a.dims=%d a=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c); + } + + return ret; +} + +static int test_diag_0() +{ + return 0 + || test_diag(RandomMat(5, 24), 3) + || test_diag(RandomMat(7, 12), 0) + || test_diag(RandomMat(3, 4), -6); +} + +static int test_diag_1() +{ + return 0 + || test_diag(RandomMat(5), -1) + || test_diag(RandomMat(7), 0) + || test_diag(RandomMat(3), 2); +} + +int main() +{ + SRAND(7767517); + + return 0 + || test_diag_0() + || test_diag_1(); +} diff --git a/tools/pnnx/src/pass_ncnn/torch_diag.cpp b/tools/pnnx/src/pass_ncnn/torch_diag.cpp new file mode 100644 index 00000000000..fd73ed32239 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_diag.cpp @@ -0,0 +1,63 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_diag : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.diag op_0 1 1 input out diagonal=%diagonal +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Diag"; + } + + const char* name_str() const + { + return "diag"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int diagonal = captured_params.at("diagonal").i; + int input_rank = op->inputs[0]->shape.size(); + + if (input_rank > 2) + { + fprintf(stderr, "diag %d-rank tensor is not supported yet!\n", input_rank); + return; + } + + op->params["0"] = diagonal; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_diag, 20) + +} // namespace ncnn + +} // namespace pnnx From 260e60978d5ca927c94fed845de4e73c25639593 Mon Sep 17 00:00:00 2001 From: Beq Jal <447727944@qq.com> Date: Mon, 14 Aug 2023 20:24:55 +0800 Subject: [PATCH 5/9] fix some bugs and alignment --- src/CMakeLists.txt | 2 +- tools/pnnx/src/pass_level2/torch_diag.cpp | 4 +- tools/pnnx/src/pass_ncnn/torch_diag.cpp | 2 +- tools/pnnx/tests/ncnn/test_torch_diag.py | 61 +++++++++++++++++++++++ 4 files changed, 65 insertions(+), 4 deletions(-) create mode 100644 tools/pnnx/tests/ncnn/test_torch_diag.py diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index be1742ac6e6..aa9cb23a279 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -71,7 +71,6 @@ ncnn_add_layer(Concat) ncnn_add_layer(Convolution) ncnn_add_layer(Crop) ncnn_add_layer(Deconvolution) -ncnn_add_layer(Diag) ncnn_add_layer(Dropout) ncnn_add_layer(Eltwise) ncnn_add_layer(ELU) @@ -163,6 +162,7 @@ ncnn_add_layer(Unfold) ncnn_add_layer(GridSample) ncnn_add_layer(CumulativeSum) ncnn_add_layer(CopyTo) +ncnn_add_layer(Diag) if(NCNN_VULKAN) ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp) diff --git a/tools/pnnx/src/pass_level2/torch_diag.cpp b/tools/pnnx/src/pass_level2/torch_diag.cpp index 3fb784b3cd9..6e78a2752aa 100644 --- a/tools/pnnx/src/pass_level2/torch_diag.cpp +++ b/tools/pnnx/src/pass_level2/torch_diag.cpp @@ -22,10 +22,10 @@ class torch_diag : public GraphRewriterPass const char* match_pattern_graph() const { return R"PNNXIR(7767517 -5 4 +4 3 pnnx.Input input_0 0 1 input pnnx.Input input_1 0 1 diagonal -aten::diag op_0 2 1 input diagonal out +aten::diag op_0 2 1 input diagonal out pnnx.Output output 1 0 out )PNNXIR"; } diff --git a/tools/pnnx/src/pass_ncnn/torch_diag.cpp b/tools/pnnx/src/pass_ncnn/torch_diag.cpp index fd73ed32239..35349f2cd1c 100644 --- a/tools/pnnx/src/pass_ncnn/torch_diag.cpp +++ b/tools/pnnx/src/pass_ncnn/torch_diag.cpp @@ -26,7 +26,7 @@ class torch_diag : public GraphRewriterPass return R"PNNXIR(7767517 3 2 pnnx.Input input 0 1 input -torch.diag op_0 1 1 input out diagonal=%diagonal +torch.diag op_0 1 1 input out diagonal=%diagonal pnnx.Output output 1 0 out )PNNXIR"; } diff --git a/tools/pnnx/tests/ncnn/test_torch_diag.py b/tools/pnnx/tests/ncnn/test_torch_diag.py new file mode 100644 index 00000000000..8af8069ef05 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_diag.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.diag(x, -1) + y = torch.diag(y, 0) + z = torch.diag(z, 3) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(7) + y = torch.rand(5, 5) + z = torch.rand(4, 8) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_diag.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_diag.pt inputshape=[7],[5,5],[4,8]") + + # ncnn inference + import test_torch_diag_ncnn + b = test_torch_diag_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) From 087e035c6d292cb63f41eed27c844e5543d3951f Mon Sep 17 00:00:00 2001 From: Beq Jal <447727944@qq.com> Date: Mon, 14 Aug 2023 20:59:13 +0800 Subject: [PATCH 6/9] update operators.md and modify cmakelist --- docs/developer-guide/operators.md | 12 ++++++++++++ tools/pnnx/tests/ncnn/CMakeLists.txt | 1 + 2 files changed, 13 insertions(+) diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index 17acf4ec03f..794833086ce 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -25,6 +25,7 @@ * [DeconvolutionDepthWise3D](#deconvolutiondepthwise3d) * [DeformableConv2D](#deformableconv2d) * [Dequantize](#dequantize) +* [Diag](#diag) * [Dropout](#dropout) * [Eltwise](#eltwise) * [ELU](#elu) @@ -749,6 +750,17 @@ y = x * scale + bias | scale_data | float | [scale_data_size] | | bias_data | float | [bias_data_size] | +# Diag +``` +y = diag(x, diagonal) +``` + +* one_blob_only + +| param id | name | type | default | description | +| --------- | ------------- | ----- | --------- | ----------------- | +| 0 | diagonal | int | 0 | | + # Dropout ``` y = x * scale diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index 945576bfaf6..2ebf27ad4a3 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -142,6 +142,7 @@ pnnx_ncnn_add_test(torch_cat) pnnx_ncnn_add_test(torch_chunk) pnnx_ncnn_add_test(torch_clone) pnnx_ncnn_add_test(torch_cumsum) +pnnx_ncnn_add_test(torch_diag) pnnx_ncnn_add_test(torch_einsum) pnnx_ncnn_add_test(torch_logsumexp) pnnx_ncnn_add_test(torch_matmul) From 191a00d8b036592d006e84e09f89bae412cf6db1 Mon Sep 17 00:00:00 2001 From: Beq Jal <447727944@qq.com> Date: Wed, 16 Aug 2023 16:31:11 +0800 Subject: [PATCH 7/9] fix the bug of std::abs --- src/layer/diag.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/layer/diag.cpp b/src/layer/diag.cpp index 936b32b2366..f6f9c64a2ac 100644 --- a/src/layer/diag.cpp +++ b/src/layer/diag.cpp @@ -37,8 +37,7 @@ int Diag::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons if (dims == 1) { int w = bottom_blob.w; - int top_w = w + std::abs(diagonal); - int stride = top_w + 1; + int top_w = w + ((diagonal >= 0) ? diagonal : -diagonal); top_blob.create(top_w, top_w, elemsize, opt.blob_allocator); if (top_blob.empty()) @@ -58,10 +57,16 @@ int Diag::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons { int w = bottom_blob.w; int h = bottom_blob.h; - float tmp = (w - h) / 2.0; - int len = std::min(w, h) - (int)std::max(std::abs(diagonal - tmp) - std::abs(tmp), 0.0f); - len = std::max(len, 0); + int len = 0; + int minimum = std::min(w - h, 0); + int maximum = std::max(w - h, 0); + if (diagonal <= maximum && diagonal >= minimum) + len = std::min(w, h); + else if (diagonal > -h && diagonal < minimum) + len = diagonal + h; + else if (diagonal > maximum && diagonal < w) + len = -diagonal + w; top_blob.create(len, elemsize, opt.blob_allocator); if (top_blob.empty()) From 42ad3481c42d283ff1799d2012bca09093d72ffa Mon Sep 17 00:00:00 2001 From: Beq Jal <447727944@qq.com> Date: Thu, 17 Aug 2023 21:24:46 +0800 Subject: [PATCH 8/9] fix a bug and modify test_diag --- src/layer/diag.cpp | 4 ++++ src/layer/diag.h | 2 +- tests/test_diag.cpp | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/layer/diag.cpp b/src/layer/diag.cpp index f6f9c64a2ac..549a353d017 100644 --- a/src/layer/diag.cpp +++ b/src/layer/diag.cpp @@ -70,7 +70,11 @@ int Diag::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons top_blob.create(len, elemsize, opt.blob_allocator); if (top_blob.empty()) + { + if (len == 0) + return 0; return -100; + } int bias_r = -std::min(diagonal, 0); int bias_c = std::max(diagonal, 0); diff --git a/src/layer/diag.h b/src/layer/diag.h index 5eaf5f38d5e..8dd0babeaef 100644 --- a/src/layer/diag.h +++ b/src/layer/diag.h @@ -34,4 +34,4 @@ class Diag : public Layer } // namespace ncnn -#endif // LAYER_DIAG_H \ No newline at end of file +#endif // LAYER_DIAG_H diff --git a/tests/test_diag.cpp b/tests/test_diag.cpp index 0846bd9052a..bb192d78ccc 100644 --- a/tests/test_diag.cpp +++ b/tests/test_diag.cpp @@ -36,6 +36,7 @@ static int test_diag_0() return 0 || test_diag(RandomMat(5, 24), 3) || test_diag(RandomMat(7, 12), 0) + || test_diag(RandomMat(6, 6), -4) || test_diag(RandomMat(3, 4), -6); } From 04e0f6f07be08b1a78288d0ad06004f5d526f748 Mon Sep 17 00:00:00 2001 From: Beq Jal <447727944@qq.com> Date: Wed, 6 Sep 2023 16:26:52 +0800 Subject: [PATCH 9/9] fix bug --- tools/pnnx/src/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 97be7a1d1a0..fc333757ad0 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -530,6 +530,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/torch_clamp.cpp pass_ncnn/torch_clone.cpp pass_ncnn/torch_cumsum.cpp + pass_ncnn/torch_diag.cpp pass_ncnn/torch_flatten.cpp pass_ncnn/torch_logsumexp.cpp pass_ncnn/torch_matmul.cpp