Skip to content

Commit

Permalink
pnnx convert fold unfold (#4325)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Nov 2, 2022
1 parent b8d40a9 commit a12c24d
Show file tree
Hide file tree
Showing 13 changed files with 445 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .ci/pnnx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,4 @@ jobs:
export MKL_NUM_THREADS=1
export MKL_ENABLE_INSTRUCTIONS=SSE4_2
cd tools/pnnx
cd build && ctest --output-on-failure -j $(nproc)
cd build && ctest --output-on-failure -j 16
8 changes: 4 additions & 4 deletions tools/pnnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ TORCH_LIBRARY(upfirdn2d_op, m) {
|nn.Embedding | :heavy_check_mark: | :heavy_check_mark: |
|nn.EmbeddingBag | |
|nn.Flatten | :heavy_check_mark: |
|nn.Fold | |
|nn.Fold | :heavy_check_mark: |
|nn.FractionalMaxPool2d | |
|nn.FractionalMaxPool3d | |
|nn.GELU | :heavy_check_mark: | :heavy_check_mark: |
Expand Down Expand Up @@ -562,7 +562,7 @@ TORCH_LIBRARY(upfirdn2d_op, m) {
|nn.TransformerEncoder | |
|nn.TransformerEncoderLayer | |
|nn.Unflatten | |
|nn.Unfold | |
|nn.Unfold | :heavy_check_mark: |
|nn.Upsample | :heavy_check_mark: | :heavy_check_mark: |
|nn.UpsamplingBilinear2d | :heavy_check_mark: | :heavy_check_mark: |
|nn.UpsamplingNearest2d | :heavy_check_mark: | :heavy_check_mark: |
Expand Down Expand Up @@ -600,7 +600,7 @@ TORCH_LIBRARY(upfirdn2d_op, m) {
|F.embedding | :heavy_check_mark: | :heavy_check_mark: |
|F.embedding_bag | |
|F.feature_alpha_dropout | :heavy_check_mark: | :heavy_check_mark: |
|F.fold | |
|F.fold | :heavy_check_mark: |
|F.fractional_max_pool2d | |
|F.fractional_max_pool3d | |
|F.gelu | :heavy_check_mark: | :heavy_check_mark: |
Expand Down Expand Up @@ -656,7 +656,7 @@ TORCH_LIBRARY(upfirdn2d_op, m) {
|F.tanhshrink | :heavy_check_mark: |
|F.threshold | :heavy_check_mark: |
|F.threshold_ | :heavy_check_mark: |
|F.unfold | |
|F.unfold | :heavy_check_mark: |
|F.upsample | :heavy_check_mark: | :heavy_check_mark: |
|F.upsample_bilinear | :heavy_check_mark: | :heavy_check_mark: |
|F.upsample_nearest | :heavy_check_mark: | :heavy_check_mark: |
4 changes: 4 additions & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ set(pnnx_pass_level1_SRCS
pass_level1/nn_Dropout3d.cpp
pass_level1/nn_ELU.cpp
pass_level1/nn_Embedding.cpp
pass_level1/nn_Fold.cpp
pass_level1/nn_GELU.cpp
pass_level1/nn_GLU.cpp
pass_level1/nn_GroupNorm.cpp
Expand Down Expand Up @@ -87,6 +88,7 @@ set(pnnx_pass_level1_SRCS
pass_level1/nn_Tanh.cpp
pass_level1/nn_Tanhshrink.cpp
pass_level1/nn_Threshold.cpp
pass_level1/nn_Unfold.cpp
pass_level1/nn_Upsample.cpp
pass_level1/nn_UpsamplingBilinear2d.cpp
pass_level1/nn_UpsamplingNearest2d.cpp
Expand Down Expand Up @@ -124,6 +126,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/F_elu.cpp
pass_level2/F_embedding.cpp
pass_level2/F_feature_alpha_dropout.cpp
pass_level2/F_fold.cpp
pass_level2/F_gelu.cpp
pass_level2/F_glu.cpp
pass_level2/F_grid_sample.cpp
Expand Down Expand Up @@ -165,6 +168,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/F_tanh.cpp
pass_level2/F_tanhshrink.cpp
pass_level2/F_threshold.cpp
pass_level2/F_unfold.cpp
pass_level2/F_upsample_bilinear.cpp
pass_level2/F_upsample_nearest.cpp
pass_level2/F_upsample.cpp
Expand Down
48 changes: 48 additions & 0 deletions tools/pnnx/src/pass_level1/nn_Fold.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"

namespace pnnx {

class Fold : public FuseModulePass
{
public:
const char* match_type_str() const
{
return "__torch__.torch.nn.modules.fold.Fold";
}

const char* type_str() const
{
return "nn.Fold";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
{
const torch::jit::Node* col2im = find_node_by_kind(graph, "aten::col2im");

op->params["output_size"] = col2im->namedInput("output_size");
op->params["kernel_size"] = col2im->namedInput("kernel_size");
op->params["stride"] = col2im->namedInput("stride");
op->params["padding"] = col2im->namedInput("padding");
op->params["dilation"] = col2im->namedInput("dilation");
}
};

REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Fold)

} // namespace pnnx
47 changes: 47 additions & 0 deletions tools/pnnx/src/pass_level1/nn_Unfold.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"

namespace pnnx {

class Unfold : public FuseModulePass
{
public:
const char* match_type_str() const
{
return "__torch__.torch.nn.modules.fold.Unfold";
}

const char* type_str() const
{
return "nn.Unfold";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
{
const torch::jit::Node* im2col = find_node_by_kind(graph, "aten::im2col");

op->params["kernel_size"] = im2col->namedInput("kernel_size");
op->params["stride"] = im2col->namedInput("stride");
op->params["padding"] = im2col->namedInput("padding");
op->params["dilation"] = im2col->namedInput("dilation");
}
};

REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Unfold)

} // namespace pnnx
45 changes: 45 additions & 0 deletions tools/pnnx/src/pass_level2/F_fold.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class F_fold : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
8 7
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 output_size
pnnx.Input input_2 0 1 kernel_size
pnnx.Input input_3 0 1 dilation
pnnx.Input input_4 0 1 padding
pnnx.Input input_5 0 1 stride
aten::col2im op_0 6 1 input output_size kernel_size dilation padding stride out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.fold";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_fold, 10)

} // namespace pnnx
44 changes: 44 additions & 0 deletions tools/pnnx/src/pass_level2/F_unfold.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class F_unfold : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 kernel_size
pnnx.Input input_2 0 1 dilation
pnnx.Input input_3 0 1 padding
pnnx.Input input_4 0 1 stride
aten::im2col op_0 5 1 input kernel_size dilation padding stride out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.unfold";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_unfold, 10)

} // namespace pnnx
4 changes: 4 additions & 0 deletions tools/pnnx/src/pass_ncnn/solve_batch_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static bool is_known_operator_with_batch_index_0(const Operator* op)
"F.conv1d",
"F.conv2d",
"F.conv3d",
"F.fold",
"F.grid_sample",
"F.group_norm",
"F.instance_norm",
Expand All @@ -54,6 +55,7 @@ static bool is_known_operator_with_batch_index_0(const Operator* op)
"F.pixel_shuffle",
"F.pixel_unshuffle",
"F.prelu",
"F.unfold",
"F.upsample_bilinear",
"F.upsample_nearest",
"F.upsample",
Expand All @@ -80,6 +82,7 @@ static bool is_known_operator_with_batch_index_0(const Operator* op)
"nn.ConvTranspose1d",
"nn.ConvTranspose2d",
"nn.ConvTranspose3d",
"nn.Fold",
"nn.GroupNorm",
"nn.InstanceNorm1d",
"nn.InstanceNorm2d",
Expand All @@ -99,6 +102,7 @@ static bool is_known_operator_with_batch_index_0(const Operator* op)
"nn.ReplicationPad2d",
"nn.ReplicationPad3d",
"nn.Softmax2d",
"nn.Unfold",
"nn.Upsample",
"nn.UpsamplingBilinear2d",
"nn.UpsamplingNearest2d",
Expand Down
4 changes: 4 additions & 0 deletions tools/pnnx/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pnnx_add_test(F_dropout3d)
pnnx_add_test(F_elu)
pnnx_add_test(F_embedding)
pnnx_add_test(F_feature_alpha_dropout)
pnnx_add_test(F_fold)
pnnx_add_test(F_gelu)
pnnx_add_test(F_glu)
pnnx_add_test(F_grid_sample)
Expand Down Expand Up @@ -70,6 +71,7 @@ pnnx_add_test(F_softsign)
pnnx_add_test(F_tanh)
pnnx_add_test(F_tanhshrink)
pnnx_add_test(F_threshold)
pnnx_add_test(F_unfold)
pnnx_add_test(F_upsample_bilinear)
pnnx_add_test(F_upsample_nearest)
pnnx_add_test(F_upsample)
Expand Down Expand Up @@ -103,6 +105,7 @@ pnnx_add_test(nn_Dropout2d)
pnnx_add_test(nn_Dropout3d)
pnnx_add_test(nn_ELU)
pnnx_add_test(nn_Embedding)
pnnx_add_test(nn_Fold)
pnnx_add_test(nn_GELU)
pnnx_add_test(nn_GLU)
pnnx_add_test(nn_GroupNorm)
Expand Down Expand Up @@ -151,6 +154,7 @@ pnnx_add_test(nn_Softsign)
pnnx_add_test(nn_Tanh)
pnnx_add_test(nn_Tanhshrink)
pnnx_add_test(nn_Threshold)
pnnx_add_test(nn_Unfold)
pnnx_add_test(nn_Upsample)
pnnx_add_test(nn_UpsamplingBilinear2d)
pnnx_add_test(nn_UpsamplingNearest2d)
Expand Down
60 changes: 60 additions & 0 deletions tools/pnnx/tests/test_F_fold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x = F.fold(x, output_size=22, kernel_size=3)
y = F.fold(y, output_size=(17,18), kernel_size=(2,4), stride=(2,1), padding=2, dilation=1)
z = F.fold(z, output_size=(5,11), kernel_size=(1,3), stride=1, padding=(2,4), dilation=1)

return x, y, z

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 108, 400)
y = torch.rand(1, 96, 190)
z = torch.rand(1, 33, 153)

a0, a1, a2 = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_F_fold.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_F_fold.pt inputshape=[1,108,400],[1,96,190],[1,33,153]")

# pnnx inference
import test_F_fold_pnnx
b0, b1, b2 = test_F_fold_pnnx.test_inference()

return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2)

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)
Loading

0 comments on commit a12c24d

Please sign in to comment.