From a3a6d742782576377e34852a10b2f256002c344b Mon Sep 17 00:00:00 2001 From: nihui Date: Sat, 29 Oct 2022 22:10:23 +0800 Subject: [PATCH] pnnx pytorch 1.13 --- .github/workflows/pnnx.yml | 3 +++ .../src/pass_level2/F_upsample_nearest.cpp | 23 +++++++++++++++++++ tools/pnnx/src/pass_level2/torch_einsum.cpp | 23 +++++++++++++++++++ 3 files changed, 49 insertions(+) diff --git a/.github/workflows/pnnx.yml b/.github/workflows/pnnx.yml index 976f9a706d4..040f3e3d738 100644 --- a/.github/workflows/pnnx.yml +++ b/.github/workflows/pnnx.yml @@ -41,6 +41,9 @@ jobs: - torch-version: 1.12.0 torchvision-version: 0.13.0 + - torch-version: 1.13.0 + torchvision-version: 0.14.0 + steps: - uses: actions/checkout@v3 with: diff --git a/tools/pnnx/src/pass_level2/F_upsample_nearest.cpp b/tools/pnnx/src/pass_level2/F_upsample_nearest.cpp index c544e8065bb..72b78b41441 100644 --- a/tools/pnnx/src/pass_level2/F_upsample_nearest.cpp +++ b/tools/pnnx/src/pass_level2/F_upsample_nearest.cpp @@ -63,6 +63,29 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_nearest_1, 10) +class F_upsample_nearest_1_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 scale_factor value=None +aten::upsample_nearest2d op_1 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample_nearest"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_nearest_1_1, 10) + class F_upsample_nearest_2 : public GraphRewriterPass { public: diff --git a/tools/pnnx/src/pass_level2/torch_einsum.cpp b/tools/pnnx/src/pass_level2/torch_einsum.cpp index 771df403c9e..f6b24757e50 100644 --- a/tools/pnnx/src/pass_level2/torch_einsum.cpp +++ b/tools/pnnx/src/pass_level2/torch_einsum.cpp @@ -38,4 +38,27 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_einsum, 20) +class torch_einsum_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 equation +pnnx.Input input_1 0 1 operands +prim::Constant op_0 0 1 path value=None +aten::einsum op_1 3 1 equation operands path out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.einsum"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_einsum_1, 20) + } // namespace pnnx