diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index fd7448bc22c..38e0516c3b4 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -198,6 +198,8 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_bitwise_and.cpp pass_level2/torch_bitwise_or.cpp pass_level2/torch_bitwise_xor.cpp + pass_level2/torch_bitwise_left_shift.cpp + pass_level2/torch_bitwise_right_shift.cpp pass_level2/torch_cat.cpp pass_level2/torch_chunk.cpp pass_level2/torch_clamp.cpp diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 1f2964d8a4f..062092fe9c4 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1134,7 +1134,7 @@ static std::string expand_expression(const Operator* op) std::string r = binaryop + "(" + a + ", " + b + ")"; exprstack.push(r); } - else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "and" || t == "or" || t == "xor") + else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift") { std::string binaryop; if (t == "add") binaryop = "+"; @@ -1145,6 +1145,8 @@ static std::string expand_expression(const Operator* op) if (t == "and") binaryop = "&"; if (t == "or") binaryop = "|"; if (t == "xor") binaryop = "^"; + if (t == "lshift") binaryop = "<<"; + if (t == "rshift") binaryop = ">>"; std::string a = exprstack.top(); exprstack.pop(); diff --git a/tools/pnnx/src/pass_level2/torch_bitwise_left_shift.cpp b/tools/pnnx/src/pass_level2/torch_bitwise_left_shift.cpp new file mode 100644 index 00000000000..4fadaad74af --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_bitwise_left_shift.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_bitwise_left_shift : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 other +aten::bitwise_left_shift op_0 2 1 input other out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.bitwise_left_shift"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_bitwise_left_shift, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_bitwise_right_shift.cpp b/tools/pnnx/src/pass_level2/torch_bitwise_right_shift.cpp new file mode 100644 index 00000000000..4db2560da3f --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_bitwise_right_shift.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_bitwise_right_shift : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 other +aten::bitwise_right_shift op_0 2 1 input other out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.bitwise_right_shift"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_bitwise_right_shift, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_expression.cpp b/tools/pnnx/src/pass_level3/fuse_expression.cpp index 40176d064bb..0866e1301c6 100644 --- a/tools/pnnx/src/pass_level3/fuse_expression.cpp +++ b/tools/pnnx/src/pass_level3/fuse_expression.cpp @@ -102,7 +102,7 @@ static bool operand_maybe_tensor(const Operand* operand) return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]); } - if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__") + if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__") { return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]); } @@ -321,11 +321,9 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants); expr += ")"; } - else if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__") + else if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__") { - std::string mathop = op->type.substr(8, 3); - if (mathop == "or_") - mathop = "or"; + std::string mathop = op->type.substr(8, op->type.size() - 10); expr += mathop; expr += "("; @@ -485,7 +483,7 @@ void fuse_expression(Graph& graph, const std::set& foldable_constan { need_fuse = true; } - if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__") + if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__") { need_fuse = true; } diff --git a/tools/pnnx/src/pass_level5/eval_expression.cpp b/tools/pnnx/src/pass_level5/eval_expression.cpp index ec9ab0e7277..7326f1b32c9 100644 --- a/tools/pnnx/src/pass_level5/eval_expression.cpp +++ b/tools/pnnx/src/pass_level5/eval_expression.cpp @@ -45,20 +45,14 @@ static bool token_is_literal(const std::string& t) float f; iss >> std::noskipws >> f; return iss.eof() && !iss.fail(); +} - // for (size_t i = 0; i < t.size(); i++) - // { - // if (i == 0 && t[i] == '-') - // continue; - // - // if (t[i] < '0' || t[i] > '9') - // { - // if (t[i] != '.' && t[i] != 'e') - // return false; - // } - // } - // - // return true; +static bool token_is_interger_literal(const std::string& t) +{ + std::istringstream iss(t); + int f; + iss >> std::noskipws >> f; + return iss.eof() && !iss.fail(); } static std::string eval_expression(const Operator* op) @@ -317,8 +311,7 @@ static std::string eval_expression(const Operator* op) || t == "div" || t == "floor_divide" || t == "pow" - || t == "remainder" - || t == "and" || t == "or" || t == "xor") + || t == "remainder") { std::string a = exprstack.top(); exprstack.pop(); @@ -379,6 +372,50 @@ static std::string eval_expression(const Operator* op) exprstack.push(r); } } + else if (t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift") + { + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + if (token_is_interger_literal(a) && token_is_interger_literal(b)) + { + int ai = std::stoi(a); + int bi = std::stoi(b); + + if (t == "and") + { + int r = ai & bi; + exprstack.push(std::to_string(r)); + } + if (t == "or") + { + int r = ai | bi; + exprstack.push(std::to_string(r)); + } + if (t == "xor") + { + int r = ai ^ bi; + exprstack.push(std::to_string(r)); + } + if (t == "lshift") + { + int r = ai << bi; + exprstack.push(std::to_string(r)); + } + if (t == "rshift") + { + int r = ai >> bi; + exprstack.push(std::to_string(r)); + } + } + else + { + std::string r = t + "(" + a + "," + b + ")"; + exprstack.push(r); + } + } else if (t == "[") // list { std::vector elements; diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 37d0c79aedd..9b618f65f8b 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -298,6 +298,11 @@ if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") pnnx_add_test(nn_Mish) endif() +if(Torch_VERSION VERSION_GREATER_EQUAL "1.10") + pnnx_add_test(torch_bitwise_left_shift) + pnnx_add_test(torch_bitwise_right_shift) +endif() + if(Torch_VERSION VERSION_GREATER_EQUAL "1.11") pnnx_add_test(torch_fft_ihfft2) pnnx_add_test(torch_fft_ihfftn) diff --git a/tools/pnnx/tests/test_torch_bitwise_left_shift.py b/tools/pnnx/tests/test_torch_bitwise_left_shift.py new file mode 100644 index 00000000000..cc60f144b11 --- /dev/null +++ b/tools/pnnx/tests/test_torch_bitwise_left_shift.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + out = torch.bitwise_left_shift(x, y) + return out + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.randint(10, (3, 16), dtype=torch.int) + y = torch.randint(10, (3, 16), dtype=torch.int) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_bitwise_left_shift.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_bitwise_left_shift.pt inputshape=[3,16]i32,[3,16]i32") + + # pnnx inference + import test_torch_bitwise_left_shift_pnnx + b = test_torch_bitwise_left_shift_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_bitwise_right_shift.py b/tools/pnnx/tests/test_torch_bitwise_right_shift.py new file mode 100644 index 00000000000..59d6c9651db --- /dev/null +++ b/tools/pnnx/tests/test_torch_bitwise_right_shift.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + out = torch.bitwise_right_shift(x, y) + return out + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.randint(10, (3, 16), dtype=torch.int) + y = torch.randint(10, (3, 16), dtype=torch.int) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_bitwise_right_shift.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_bitwise_right_shift.pt inputshape=[3,16]i32,[3,16]i32") + + # pnnx inference + import test_torch_bitwise_right_shift_pnnx + b = test_torch_bitwise_right_shift_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)