Skip to content

Commit

Permalink
pnnx load gpu torchscript and reset device (#4330)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Nov 4, 2022
1 parent 5b28c17 commit 92da26b
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 8 deletions.
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
set(pnnx_pass_level0_SRCS
pass_level0/constant_unpooling.cpp
pass_level0/inline_block.cpp
pass_level0/reset_device.cpp
pass_level0/shape_inference.cpp
)

Expand Down
4 changes: 2 additions & 2 deletions tools/pnnx/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ int main(int argc, char** argv)

try
{
mod = torch::jit::load(ptpath);
mod = torch::jit::load(ptpath, (device == "gpu") ? c10::kCUDA : c10::kCPU);
}
catch (const c10::Error& e)
{
Expand Down Expand Up @@ -359,7 +359,7 @@ int main(int argc, char** argv)
fprintf(stderr, "############# pass_level0\n");

std::map<std::string, pnnx::Attribute> foldable_constants;
pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, foldable_constants);
pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants);

// g->dump();

Expand Down
7 changes: 5 additions & 2 deletions tools/pnnx/src/pass_level0.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,22 @@

#include "pass_level0/constant_unpooling.h"
#include "pass_level0/inline_block.h"
#include "pass_level0/reset_device.h"
#include "pass_level0/shape_inference.h"

namespace pnnx {

// Run the level-0 graph passes in order: inline composite modules, rewrite
// aten::to device arguments to the requested target, unpool shared constants,
// and (when example inputs are given) run shape inference to collect
// foldable constants.
//
// mod                - the loaded TorchScript module
// g                  - its forward graph, mutated in place
// input_tensors      - example inputs; empty disables shape inference
// input_tensors2     - optional second input set (for dynamic shapes)
// module_operators   - module types kept as single ops (not inlined)
// ptpath             - path to the .pt file (reloaded during shape inference)
// device             - "gpu" or "cpu"; target device written into the graph
// foldable_constants - out: constants found foldable during shape inference
void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants)
{
    inline_block(g, module_operators);

    reset_device(g, device);

    constant_unpooling(g);

    // Shape inference needs at least one set of example inputs.
    if (!input_tensors.empty())
    {
        shape_inference(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants);
    }
}

Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level0.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

namespace pnnx {

void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants);
void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants);

} // namespace pnnx

Expand Down
36 changes: 36 additions & 0 deletions tools/pnnx/src/pass_level0/reset_device.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "reset_device.h"
#include "../pass_level1.h"

namespace pnnx {

// Rewrite the device argument of every aten::to node in the graph so tensors
// are moved to the requested target device ("gpu" -> "cuda", anything else
// -> "cpu"). This lets a torchscript traced on one device be converted for
// another.
//
// graph  - the TorchScript graph, mutated in place
// device - "gpu" selects cuda; any other value selects cpu
void reset_device(std::shared_ptr<torch::jit::Graph>& graph, const std::string& device)
{
    // Hoist loop-invariant values out of the node scan.
    const std::string aten_to("aten::to");
    const std::string target_device = (device == "gpu") ? "cuda" : "cpu";

    for (torch::jit::Node* n : graph->nodes())
    {
        if (n->kind().toDisplayString() == aten_to)
        {
            if (n->hasNamedInput("device"))
            {
                // NOTE(review): assumes the "device" input comes from a
                // constant node carrying a string "value" attribute — confirm
                // this also holds for prim::Constant nodes of Device type.
                torch::jit::Node* device_node = n->namedInput("device")->node();

                device_node->s_(torch::jit::attr::value, target_device);
            }
        }
    }
}

} // namespace pnnx
21 changes: 21 additions & 0 deletions tools/pnnx/src/pass_level0/reset_device.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef PNNX_PASS_LEVEL0_RESET_DEVICE_H
#define PNNX_PASS_LEVEL0_RESET_DEVICE_H

#include <torch/script.h>

namespace pnnx {

// Rewrite aten::to device arguments in the graph to the requested target
// device ("gpu" -> cuda, anything else -> cpu). Mutates the graph in place.
void reset_device(std::shared_ptr<torch::jit::Graph>& graph, const std::string& device);

} // namespace pnnx

#endif // PNNX_PASS_LEVEL0_RESET_DEVICE_H
7 changes: 5 additions & 2 deletions tools/pnnx/src/pass_level0/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "pass_level0/constant_unpooling.h"
#include "pass_level0/inline_block.h"
#include "pass_level0/reset_device.h"
#include "pass_level0/shape_inference.h"

namespace pnnx {
Expand Down Expand Up @@ -77,7 +78,7 @@ static bool value_link_output(const torch::jit::Value* v, const std::vector<torc
return false;
}

void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants)
void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants)
{
// collect all intermediate output tensors
std::vector<std::unordered_set<std::string> > more_value_names;
Expand Down Expand Up @@ -150,13 +151,15 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::

// auto mod2 = mod.deepcopy();

torch::jit::Module mod2 = torch::jit::load(ptpath);
torch::jit::Module mod2 = torch::jit::load(ptpath, (device == "gpu") ? c10::kCUDA : c10::kCPU);
mod2.eval();

auto graph2 = mod2.get_method("forward").graph();

inline_block(graph2, module_operators);

reset_device(graph2, device);

constant_unpooling(graph2);

std::vector<torch::jit::Value*> values2;
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level0/shape_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@

namespace pnnx {

void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants);
void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants);

} // namespace pnnx

0 comments on commit 92da26b

Please sign in to comment.