From 92da26be7942fecac02c1186ed572bddeb08cc7d Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 4 Nov 2022 16:58:54 +0800 Subject: [PATCH] pnnx load gpu torchscript and reset device (#4330) --- tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/main.cpp | 4 +-- tools/pnnx/src/pass_level0.cpp | 7 ++-- tools/pnnx/src/pass_level0.h | 2 +- tools/pnnx/src/pass_level0/reset_device.cpp | 36 +++++++++++++++++++ tools/pnnx/src/pass_level0/reset_device.h | 21 +++++++++++ .../pnnx/src/pass_level0/shape_inference.cpp | 7 ++-- tools/pnnx/src/pass_level0/shape_inference.h | 2 +- 8 files changed, 72 insertions(+), 8 deletions(-) create mode 100644 tools/pnnx/src/pass_level0/reset_device.cpp create mode 100644 tools/pnnx/src/pass_level0/reset_device.h diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 2e0eb5d8456..f29437f541a 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -4,6 +4,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) set(pnnx_pass_level0_SRCS pass_level0/constant_unpooling.cpp pass_level0/inline_block.cpp + pass_level0/reset_device.cpp pass_level0/shape_inference.cpp ) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 87ecfecd648..57290fa27aa 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -327,7 +327,7 @@ int main(int argc, char** argv) try { - mod = torch::jit::load(ptpath); + mod = torch::jit::load(ptpath, (device == "gpu") ? c10::kCUDA : c10::kCPU); } catch (const c10::Error& e) { @@ -359,7 +359,7 @@ int main(int argc, char** argv) fprintf(stderr, "############# pass_level0\n"); std::map foldable_constants; - pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, foldable_constants); + pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants); // g->dump(); diff --git a/tools/pnnx/src/pass_level0.cpp b/tools/pnnx/src/pass_level0.cpp index d50f71bbe29..a76d6766e3f 100644 --- a/tools/pnnx/src/pass_level0.cpp +++ b/tools/pnnx/src/pass_level0.cpp @@ -16,19 +16,22 @@ #include "pass_level0/constant_unpooling.h" #include "pass_level0/inline_block.h" +#include "pass_level0/reset_device.h" #include "pass_level0/shape_inference.h" namespace pnnx { -void pass_level0(const torch::jit::Module& mod, std::shared_ptr& g, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, std::map& foldable_constants) +void pass_level0(const torch::jit::Module& mod, std::shared_ptr& g, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, const std::string& device, std::map& foldable_constants) { inline_block(g, module_operators); + reset_device(g, device); + constant_unpooling(g); if (!input_tensors.empty()) { - shape_inference(mod, g, input_tensors, input_tensors2, module_operators, ptpath, foldable_constants); + shape_inference(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants); } } diff --git a/tools/pnnx/src/pass_level0.h b/tools/pnnx/src/pass_level0.h index 11543ddc8ff..00dfc8d8ab5 100644 --- a/tools/pnnx/src/pass_level0.h +++ b/tools/pnnx/src/pass_level0.h @@ -20,7 +20,7 @@ namespace pnnx { -void pass_level0(const torch::jit::Module& mod, std::shared_ptr& g, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, std::map& foldable_constants); +void pass_level0(const torch::jit::Module& mod, std::shared_ptr& g, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, const std::string& device, std::map& foldable_constants); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/reset_device.cpp b/tools/pnnx/src/pass_level0/reset_device.cpp new file mode 100644 index 00000000000..b817e41a1f4 --- /dev/null +++ b/tools/pnnx/src/pass_level0/reset_device.cpp @@ -0,0 +1,36 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "reset_device.h" +#include "../pass_level1.h" + +namespace pnnx { + +void reset_device(std::shared_ptr& graph, const std::string& device) +{ + for (torch::jit::Node* n : graph->nodes()) + { + if (n->kind().toDisplayString() == std::string("aten::to")) + { + if (n->hasNamedInput("device")) + { + torch::jit::Node* device_node = n->namedInput("device")->node(); + + device_node->s_(torch::jit::attr::value, (device == "gpu") ? "cuda" : "cpu"); + } + } + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/reset_device.h b/tools/pnnx/src/pass_level0/reset_device.h new file mode 100644 index 00000000000..17d8f93995e --- /dev/null +++ b/tools/pnnx/src/pass_level0/reset_device.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include + +namespace pnnx { + +void reset_device(std::shared_ptr& graph, const std::string& device); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/shape_inference.cpp b/tools/pnnx/src/pass_level0/shape_inference.cpp index cc0c19f9ccf..98e2ad45d41 100644 --- a/tools/pnnx/src/pass_level0/shape_inference.cpp +++ b/tools/pnnx/src/pass_level0/shape_inference.cpp @@ -17,6 +17,7 @@ #include "pass_level0/constant_unpooling.h" #include "pass_level0/inline_block.h" +#include "pass_level0/reset_device.h" #include "pass_level0/shape_inference.h" namespace pnnx { @@ -77,7 +78,7 @@ static bool value_link_output(const torch::jit::Value* v, const std::vector& graph, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, std::map& foldable_constants) +void shape_inference(const torch::jit::Module& mod, std::shared_ptr& graph, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, const std::string& device, std::map& foldable_constants) { // collect all intermediate output tensors std::vector > more_value_names; @@ -150,13 +151,15 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr values2; diff --git a/tools/pnnx/src/pass_level0/shape_inference.h b/tools/pnnx/src/pass_level0/shape_inference.h index cf80ade7abe..ee2c461179d 100644 --- a/tools/pnnx/src/pass_level0/shape_inference.h +++ b/tools/pnnx/src/pass_level0/shape_inference.h @@ -18,6 +18,6 @@ namespace pnnx { -void shape_inference(const torch::jit::Module& mod, std::shared_ptr& graph, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, std::map& foldable_constants); +void shape_inference(const torch::jit::Module& mod, std::shared_ptr& graph, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, const std::string& device, std::map& foldable_constants); } // namespace pnnx