Skip to content

Commit

Permalink
add diag layer and its converter (#4935)
Browse files Browse the repository at this point in the history
  • Loading branch information
wnqn1597 authored Sep 8, 2023
1 parent 695f770 commit c851231
Show file tree
Hide file tree
Showing 13 changed files with 429 additions and 0 deletions.
12 changes: 12 additions & 0 deletions docs/developer-guide/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
* [DeconvolutionDepthWise3D](#deconvolutiondepthwise3d)
* [DeformableConv2D](#deformableconv2d)
* [Dequantize](#dequantize)
* [Diag](#diag)
* [Dropout](#dropout)
* [Eltwise](#eltwise)
* [ELU](#elu)
Expand Down Expand Up @@ -749,6 +750,17 @@ y = x * scale + bias
| scale_data | float | [scale_data_size] |
| bias_data | float | [bias_data_size] |

# Diag
```
y = diag(x, diagonal)
```

* one_blob_only

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | diagonal | int | 0 | |

# Dropout
```
y = x * scale
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ ncnn_add_layer(GridSample)
ncnn_add_layer(CumulativeSum)
ncnn_add_layer(CopyTo)
ncnn_add_layer(Erf)
ncnn_add_layer(Diag)

if(NCNN_VULKAN)
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)
Expand Down
91 changes: 91 additions & 0 deletions src/layer/diag.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "diag.h"

namespace ncnn {

Diag::Diag()
{
one_blob_only = true;
support_inplace = false;
}

int Diag::load_param(const ParamDict& pd)
{
diagonal = pd.get(0, 0);

return 0;
}

int Diag::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
int dims = bottom_blob.dims;
size_t elemsize = bottom_blob.elemsize;

if (dims == 1)
{
int w = bottom_blob.w;
int top_w = w + ((diagonal >= 0) ? diagonal : -diagonal);

top_blob.create(top_w, top_w, elemsize, opt.blob_allocator);
if (top_blob.empty())
return -100;

top_blob.fill(0.0f);

int bias_r = -std::min(diagonal, 0);
int bias_c = std::max(diagonal, 0);

for (int i = 0; i < w; i++)
{
top_blob.row(i + bias_r)[i + bias_c] = bottom_blob[i];
}
}
if (dims == 2)
{
int w = bottom_blob.w;
int h = bottom_blob.h;

int len = 0;
int minimum = std::min(w - h, 0);
int maximum = std::max(w - h, 0);
if (diagonal <= maximum && diagonal >= minimum)
len = std::min(w, h);
else if (diagonal > -h && diagonal < minimum)
len = diagonal + h;
else if (diagonal > maximum && diagonal < w)
len = -diagonal + w;

top_blob.create(len, elemsize, opt.blob_allocator);
if (top_blob.empty())
{
if (len == 0)
return 0;
return -100;
}

int bias_r = -std::min(diagonal, 0);
int bias_c = std::max(diagonal, 0);

for (int i = 0; i < len; i++)
{
top_blob[i] = bottom_blob.row(i + bias_r)[i + bias_c];
}
}

return 0;
}

} // namespace ncnn
37 changes: 37 additions & 0 deletions src/layer/diag.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_DIAG_H
#define LAYER_DIAG_H

#include "layer.h"

namespace ncnn {

class Diag : public Layer
{
public:
Diag();

virtual int load_param(const ParamDict& pd);

virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;

public:
int diagonal;
};

} // namespace ncnn

#endif // LAYER_DIAG_H
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ ncnn_add_layer_test(DeconvolutionDepthWise3D)
ncnn_add_layer_test(DeepCopy)
ncnn_add_layer_test(DeformableConv2D)
ncnn_add_layer_test(Dequantize)
ncnn_add_layer_test(Diag)
ncnn_add_layer_test(Dropout)
ncnn_add_layer_test(Einsum)
ncnn_add_layer_test(Eltwise)
Expand Down
58 changes: 58 additions & 0 deletions tests/test_diag.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "layer/diag.h"
#include "testutil.h"

static int test_diag(const ncnn::Mat& a, int diagonal)
{
ncnn::ParamDict pd;
pd.set(0, diagonal);

std::vector<ncnn::Mat> weights(0);

int ret = test_layer<ncnn::Diag>("Diag", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_diag failed a.dims=%d a=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c);
}

return ret;
}

static int test_diag_0()
{
return 0
|| test_diag(RandomMat(5, 24), 3)
|| test_diag(RandomMat(7, 12), 0)
|| test_diag(RandomMat(6, 6), -4)
|| test_diag(RandomMat(3, 4), -6);
}

static int test_diag_1()
{
return 0
|| test_diag(RandomMat(5), -1)
|| test_diag(RandomMat(7), 0)
|| test_diag(RandomMat(3), 2);
}

int main()
{
SRAND(7767517);

return 0
|| test_diag_0()
|| test_diag_1();
}
2 changes: 2 additions & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_cumprod.cpp
pass_level2/torch_cumsum.cpp
pass_level2/torch_dequantize.cpp
pass_level2/torch_diag.cpp
pass_level2/torch_einsum.cpp
pass_level2/torch_empty.cpp
pass_level2/torch_empty_like.cpp
Expand Down Expand Up @@ -534,6 +535,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/torch_clamp.cpp
pass_ncnn/torch_clone.cpp
pass_ncnn/torch_cumsum.cpp
pass_ncnn/torch_diag.cpp
pass_ncnn/torch_flatten.cpp
pass_ncnn/torch_logsumexp.cpp
pass_ncnn/torch_matmul.cpp
Expand Down
41 changes: 41 additions & 0 deletions tools/pnnx/src/pass_level2/torch_diag.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_diag : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 diagonal
aten::diag op_0 2 1 input diagonal out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.diag";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_diag, 20)

} // namespace pnnx
63 changes: 63 additions & 0 deletions tools/pnnx/src/pass_ncnn/torch_diag.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class torch_diag : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.diag op_0 1 1 input out diagonal=%diagonal
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Diag";
}

const char* name_str() const
{
return "diag";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
int diagonal = captured_params.at("diagonal").i;
int input_rank = op->inputs[0]->shape.size();

if (input_rank > 2)
{
fprintf(stderr, "diag %d-rank tensor is not supported yet!\n", input_rank);
return;
}

op->params["0"] = diagonal;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_diag, 20)

} // namespace ncnn

} // namespace pnnx
1 change: 1 addition & 0 deletions tools/pnnx/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ pnnx_add_test(torch_cumprod)
pnnx_add_test(torch_cumsum)
pnnx_add_test(torch_einsum)
pnnx_add_test(torch_eq)
pnnx_add_test(torch_diag)
pnnx_add_test(torch_flatten)
pnnx_add_test(torch_full)
pnnx_add_test(torch_full_like)
Expand Down
1 change: 1 addition & 0 deletions tools/pnnx/tests/ncnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ pnnx_ncnn_add_test(torch_cat)
pnnx_ncnn_add_test(torch_chunk)
pnnx_ncnn_add_test(torch_clone)
pnnx_ncnn_add_test(torch_cumsum)
pnnx_ncnn_add_test(torch_diag)
pnnx_ncnn_add_test(torch_einsum)
pnnx_ncnn_add_test(torch_logsumexp)
pnnx_ncnn_add_test(torch_matmul)
Expand Down
Loading

0 comments on commit c851231

Please sign in to comment.