diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index c2004c8da00..ab8e7406a2b 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -174,6 +174,7 @@ set(pnnx_pass_level2_SRCS pass_level2/F_upsample_nearest.cpp pass_level2/F_upsample.cpp pass_level2/Tensor_contiguous.cpp + pass_level2/Tensor_copy.cpp pass_level2/Tensor_expand.cpp pass_level2/Tensor_expand_as.cpp pass_level2/Tensor_index.cpp @@ -314,6 +315,7 @@ set(pnnx_pass_level5_SRCS pass_level5/fuse_contiguous_view.cpp pass_level5/fuse_linear_batchnorm1d.cpp pass_level5/fuse_select_to_unbind.cpp + pass_level5/fuse_slice_copy.cpp pass_level5/fuse_slice_indices.cpp pass_level5/fuse_slice_to_tensor_split.cpp pass_level5/fuse_static_conv.cpp diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 0f44763afc2..d688678b5c3 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1658,6 +1658,13 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) std::string slice_expr = make_slice_expression(op); fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), slice_expr.c_str()); } + else if (op->type == "Tensor.slice_copy") + { + // slice copy expr + std::string slice_expr = make_slice_expression(op); + fprintf(pyfp, "v_%s = v_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str()); + fprintf(pyfp, " v_%s[%s] = v_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), slice_expr.c_str(), sanitize_identifier(op->inputs[1]->name).c_str()); + } else if (op->type == "Tensor.index") { // index expr diff --git a/tools/pnnx/src/pass_level1.cpp b/tools/pnnx/src/pass_level1.cpp index 9b5e4460905..0aaf4d897a8 100644 --- a/tools/pnnx/src/pass_level1.cpp +++ b/tools/pnnx/src/pass_level1.cpp @@ -376,10 +376,6 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptrkind().toDisplayString(), name); - // always treat inplace op type as non-inplace version - if (op->type.size() > 2 && op->type[op->type.size() - 2] != '_' && op->type[op->type.size() - 1] == '_') - op->type = op->type.substr(0, op->type.size() - 1); - for (int i = 0; i < (int)n->inputs().size(); i++) { const auto& in = n->input(i); diff --git a/tools/pnnx/src/pass_level2.cpp b/tools/pnnx/src/pass_level2.cpp index b461d1e5c0f..a124789f3c3 100644 --- a/tools/pnnx/src/pass_level2.cpp +++ b/tools/pnnx/src/pass_level2.cpp @@ -502,8 +502,112 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde } } +static void fix_inplace_copy_output(Graph& graph) +{ + while (1) + { + bool matched = false; + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + bool is_inplace_op = op->type.size() > 2 && op->type[op->type.size() - 2] != '_' && op->type[op->type.size() - 1] == '_'; + if (!is_inplace_op) + continue; + + // replace inplace op with non-inplace version + op->type = op->type.substr(0, op->type.size() - 1); + + if (op->type == "aten::copy") + continue; + + if (op->outputs[0]->consumers.size() != 0) + continue; + + matched = true; + + // find in0 from slice / select chain + Operand* in0 = op->inputs[0]; + while (in0->producer->type == "aten::slice" || in0->producer->type == "aten::select") + { + in0 = in0->producer->inputs[0]; + } + + // append copy for inplace op + Operator* op_copy = graph.new_operator_after("aten::copy", op->name + "_copy", op); + Operand* copy_out = graph.new_operand(op->name + "_copy_out"); + + copy_out->shape = in0->shape; + + op_copy->inputs.push_back(op->inputs[0]); + op_copy->inputs.push_back(op->outputs[0]); + op->inputs[0]->consumers.push_back(op_copy); + op->outputs[0]->consumers.push_back(op_copy); + + op_copy->outputs.push_back(copy_out); + copy_out->producer = op_copy; + + break; + } + + if (!matched) + break; + } + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "aten::copy") + continue; + + if (op->outputs[0]->consumers.size() != 0) + continue; + + // aten::slice 5 1 in0 .... a + // aten::slice 5 1 a .... b + // aten::copy 2 1 b in1 out + + // aten::select 3 1 in0 .... a + // aten::copy 2 1 a in1 out + + // find in0 from slice / select chain + Operand* in0 = op->inputs[0]; + while (in0->producer->type == "aten::slice" || in0->producer->type == "aten::select") + { + in0 = in0->producer->inputs[0]; + } + + // replace all the following uses of in0 with out + Operand* out0 = op->outputs[0]; + out0->shape = in0->shape; + for (size_t j = i; j < graph.ops.size(); j++) + { + Operator* op2 = graph.ops[j]; + + bool use_in0 = false; + for (size_t k = 0; k < op2->inputs.size(); k++) + { + if (op2->inputs[k] == in0) + { + op2->inputs[k] = out0; + use_in0 = true; + } + } + + if (use_in0) + { + in0->remove_consumer(op2); + out0->consumers.push_back(op2); + } + } + } +} + void pass_level2(Graph& g) { + fix_inplace_copy_output(g); + int opindex = 0; for (auto x : g_global_pnnx_graph_rewriter_passes) { diff --git a/tools/pnnx/src/pass_level2/Tensor_copy.cpp b/tools/pnnx/src/pass_level2/Tensor_copy.cpp new file mode 100644 index 00000000000..d5369b29e8a --- /dev/null +++ b/tools/pnnx/src/pass_level2/Tensor_copy.cpp @@ -0,0 +1,64 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class Tensor_copy : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 self +pnnx.Input input_1 0 1 src +prim::Constant op_0 0 1 non_blocking value=* +aten::copy op_1 3 1 self src non_blocking out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.copy"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_copy, 20) + +class Tensor_copy_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 self +pnnx.Input input_1 0 1 src +aten::copy op_1 2 1 self src out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.copy"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_copy_1, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index 611e5234c25..d38316f54dd 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -34,6 +34,7 @@ #include "pass_level5/fuse_contiguous_view.h" #include "pass_level5/fuse_linear_batchnorm1d.h" #include "pass_level5/fuse_select_to_unbind.h" +#include "pass_level5/fuse_slice_copy.h" #include "pass_level5/fuse_slice_indices.h" #include "pass_level5/fuse_slice_to_tensor_split.h" #include "pass_level5/fuse_static_conv.h" @@ -66,6 +67,8 @@ void pass_level5(Graph& g, const std::set& foldable_constants, cons fuse_slice_to_tensor_split(g); + fuse_slice_copy(g); + fuse_static_conv(g); fuse_conv1d_batchnorm1d(g); diff --git a/tools/pnnx/src/pass_level5/fold_constants.cpp b/tools/pnnx/src/pass_level5/fold_constants.cpp index 8767f2c5d78..e5bccd49827 100644 --- a/tools/pnnx/src/pass_level5/fold_constants.cpp +++ b/tools/pnnx/src/pass_level5/fold_constants.cpp @@ -22,6 +22,9 @@ namespace pnnx { void fold_constants(Graph& graph, const std::set& foldable_constants, const std::string& foldable_constants_zippath) { + if (foldable_constants.empty()) + return; + StoreZipReader zip; zip.open(foldable_constants_zippath); diff --git a/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp b/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp new file mode 100644 index 00000000000..0a5fabab7af --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp @@ -0,0 +1,279 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_slice_copy.h" + +#include +#include +#include +#include "pass_level2.h" + +namespace pnnx { + +void fuse_slice_copy(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "Tensor.copy") + continue; + + // collect slice / select op chain + std::stack slice_select_ops; + int descent_dim_current = INT_MAX; + const Operand* in0 = op->inputs[0]; + while (in0->producer->type == "Tensor.slice" || in0->producer->type == "Tensor.select") + { + const Operator* sop = in0->producer; + if (sop->type == "Tensor.slice") + { + if (sop->params.find("dims") == sop->params.end() + || sop->params.find("starts") == sop->params.end() + || sop->params.find("ends") == sop->params.end() + || sop->params.find("steps") == sop->params.end()) + { + fprintf(stderr, "dynamic index in slice copy chain is not supported\n"); + break; + } + + int dims0 = sop->params.at("dims").ai[0]; + if (descent_dim_current < dims0) + { + break; + } + + descent_dim_current = dims0; + } + + if (sop->type == "Tensor.select") + { + if (sop->params.find("dim") == sop->params.end() + || sop->params.find("index") == sop->params.end()) + { + fprintf(stderr, "dynamic index in select copy chain is not supported\n"); + break; + } + + int dim = sop->params.at("dim").i; + if (descent_dim_current < dim) + { + break; + } + + descent_dim_current = dim; + } + + slice_select_ops.push(sop); + in0 = sop->inputs[0]; + } + + matched = true; + + if (slice_select_ops.empty()) + { + // eliminate noop copy + Operand* out = op->outputs[0]; + + for (auto& x : out->consumers) + { + for (size_t j = 0; j < x->inputs.size(); j++) + { + if (x->inputs[j] == out) + x->inputs[j] = op->inputs[1]; + } + + op->inputs[1]->consumers.push_back(x); + } + + op->inputs[0]->remove_consumer(op); + op->inputs[1]->remove_consumer(op); + + op->inputs[1]->name = out->name; + + out->producer = 0; + out->consumers.clear(); + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), out)); + delete out; + + op->inputs.clear(); + op->outputs.clear(); + + graph.ops.erase(graph.ops.begin() + i); + delete op; + + break; + } + + const Operator* top_sop = slice_select_ops.top(); + + // construct one-step slice + std::vector new_dims; + std::vector new_starts; + std::vector new_ends; + std::vector new_steps; + + int select_dims_offset = 0; + while (!slice_select_ops.empty()) + { + const Operator* sop = slice_select_ops.top(); + slice_select_ops.pop(); + + if (sop->type == "Tensor.slice") + { + std::vector dims = sop->params.at("dims").ai; + std::vector starts = sop->params.at("starts").ai; + std::vector ends = sop->params.at("ends").ai; + std::vector steps = sop->params.at("steps").ai; + + for (size_t j = 0; j < dims.size(); j++) + { + dims[j] += select_dims_offset; + } + + new_dims.insert(new_dims.end(), dims.begin(), dims.end()); + new_starts.insert(new_starts.end(), starts.begin(), starts.end()); + new_ends.insert(new_ends.end(), ends.begin(), ends.end()); + new_steps.insert(new_steps.end(), steps.begin(), steps.end()); + } + else if (sop->type == "Tensor.select") + { + int dim = sop->params.at("dim").i; + int index = sop->params.at("index").i; + + dim += select_dims_offset; + int end = index + 1; + if (index == -1) + end = INT_MAX; + + new_dims.push_back(dim); + new_starts.push_back(index); + new_ends.push_back(end); + new_steps.push_back(1); + + select_dims_offset += 1; + } + } + + op->type = "Tensor.slice_copy"; + + // insert clone before any slices + Operator* op_clone = graph.new_operator_before("Tensor.clone", op->name + "_ncnnclone", top_sop); + Operand* clone_out = graph.new_operand(op->name + "_ncnnclone_out"); + + clone_out->shape = top_sop->inputs[0]->shape; + + op_clone->inputs.push_back(top_sop->inputs[0]); + top_sop->inputs[0]->consumers.push_back(op_clone); + + op_clone->outputs.push_back(clone_out); + clone_out->producer = op_clone; + + op->inputs[0]->remove_consumer(op); + op->inputs[0] = clone_out; + clone_out->consumers.push_back(op); + + op->params["dims"] = new_dims; + op->params["starts"] = new_starts; + op->params["ends"] = new_ends; + op->params["steps"] = new_steps; + + int input_rank = (int)op->inputs[0]->shape.size(); + if (input_rank == 0) + { + // insert view_as(sliced) for different or unknown rank + Operator* op_slice = graph.new_operator_before("Tensor.slice", op->name + "_ncnnslice", op); + Operator* op_view_as = graph.new_operator_before("Tensor.view_as", op->name + "_ncnnview_as", op); + + Operand* slice_out = graph.new_operand(op->name + "_ncnnslice_out"); + Operand* view_as_out = graph.new_operand(op->name + "_ncnnview_as_out"); + + op_slice->params["dims"] = new_dims; + op_slice->params["starts"] = new_starts; + op_slice->params["ends"] = new_ends; + op_slice->params["steps"] = new_steps; + + op_slice->inputs.push_back(op->inputs[0]); + op->inputs[0]->consumers.push_back(op_slice); + + op_slice->outputs.push_back(slice_out); + slice_out->producer = op_slice; + + op_view_as->inputs.push_back(op->inputs[1]); + op->inputs[1]->consumers.push_back(op_view_as); + op->inputs[1]->remove_consumer(op); + op_view_as->inputs.push_back(slice_out); + slice_out->consumers.push_back(op_view_as); + + op_view_as->outputs.push_back(view_as_out); + view_as_out->producer = op_view_as; + + op->inputs[1] = view_as_out; + view_as_out->consumers.push_back(op); + } + else if (input_rank != (int)op->inputs[1]->shape.size()) + { + // solve the target shape + std::vector target_shape = op->inputs[0]->shape; + for (size_t j = 0; j < new_dims.size(); j++) + { + int dim = new_dims[j]; + int start = new_starts[j]; + int end = new_ends[j]; + int step = new_steps[j]; + + if (dim < 0) + dim = input_rank + dim; + if (start < 0) + start = target_shape[dim] + start; + if (end < 0) + end = target_shape[dim] + end; + if (end == INT_MAX) + end = target_shape[dim]; + + target_shape[dim] = (end - start + (step - 1)) / step; + } + + Operator* op_view = graph.new_operator_before("Tensor.view", op->name + "_ncnnview", op); + Operand* view_out = graph.new_operand(op->name + "_ncnnview_out"); + + op_view->params["shape"] = target_shape; + + view_out->shape = target_shape; + + op_view->inputs.push_back(op->inputs[1]); + op->inputs[1]->consumers.push_back(op_view); + op->inputs[1]->remove_consumer(op); + + op_view->outputs.push_back(view_out); + view_out->producer = op_view; + + op->inputs[1] = view_out; + view_out->consumers.push_back(op); + } + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_slice_copy.h b/tools/pnnx/src/pass_level5/fuse_slice_copy.h new file mode 100644 index 00000000000..db3aef77359 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_slice_copy.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_slice_copy(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 4349b63ce1e..37d0c79aedd 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -171,6 +171,7 @@ pnnx_add_test(Tensor_repeat) pnnx_add_test(Tensor_reshape) pnnx_add_test(Tensor_select) pnnx_add_test(Tensor_slice) +pnnx_add_test(Tensor_slice_copy) pnnx_add_test(Tensor_view) pnnx_add_test(torch_addmm) diff --git a/tools/pnnx/tests/test_Tensor_slice_copy.py b/tools/pnnx/tests/test_Tensor_slice_copy.py new file mode 100644 index 00000000000..e3c76a2b867 --- /dev/null +++ b/tools/pnnx/tests/test_Tensor_slice_copy.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = x.clone() + x[2:10,...] += 1 + x[...,1] = x[...,-1] * 3 + y = x.clone() + x[:,:,3,::2].clamp_(0, 0.5) + x[:,:,3,::2] = x[:,:,3,::2].exp_() + x[:,:,::2,:] = y[:,:,::2,:].pow(2) + x[:,:,:,:] = x[:,:,:,:] / 2 + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(18, 15, 19, 20) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_Tensor_slice_copy.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_Tensor_slice_copy.pt inputshape=[18,15,19,20]") + + # pnnx inference + import test_Tensor_slice_copy_pnnx + b = test_Tensor_slice_copy_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)