[Feature] TorchScript SDK backend #890

Merged
merged 11 commits into from
Aug 29, 2022
20 changes: 11 additions & 9 deletions cmake/cuda.cmake
Expand Up @@ -16,19 +16,19 @@ find_package(CUDA REQUIRED)

if (MSVC)
set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc.exe)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler=/wd4819,/wd4828")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/wd4819,/wd4828")
if (HAVE_CXX_FLAG_UTF_8)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler=/utf-8")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/utf-8")
endif ()
else ()
set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc)
# Explicitly set the CUDA host compiler, because the default host compiler
# selected by CMake may be wrong.
set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER})
set(CUDA_NVCC_FLAGS
"${CUDA_NVCC_FLAGS} -Xcompiler=-fPIC,-Wall,-fvisibility=hidden")
set(CMAKE_CUDA_FLAGS
"${CMAKE_CUDA_FLAGS} -Xcompiler=-fPIC,-Wall,-fvisibility=hidden")
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler=-fno-gnu-unique")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-gnu-unique")
endif ()
endif ()

Expand Down Expand Up @@ -62,10 +62,12 @@ if (NOT CMAKE_CUDA_ARCHITECTURES)
endif ()
endif ()

set(CUDA_NVCC_FLAGS_DEBUG "-g -O0")
set(CUDA_NVCC_FLAGS_RELEASE "-O3")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS}")
set(CMAKE_CUDA_FLAGS_DEBUG "-g -O0")
set(CMAKE_CUDA_FLAGS_RELEASE "-O3")

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMMDEPLOY_USE_CUDA=1")

if (NOT MSVC)
set(CMAKE_CUDA_STANDARD 14)
endif ()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${CUDA_NVCC_FLAGS} ${_NVCC_FLAGS}")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${_NVCC_FLAGS}")
4 changes: 4 additions & 0 deletions csrc/mmdeploy/net/CMakeLists.txt
Expand Up @@ -26,5 +26,9 @@ if ("snpe" IN_LIST MMDEPLOY_TARGET_BACKENDS)
add_subdirectory(snpe)
endif ()

if ("torchscript" IN_LIST MMDEPLOY_TARGET_BACKENDS)
add_subdirectory(torchscript)
endif ()

mmdeploy_add_module(${PROJECT_NAME} net_module.cpp)
add_library(mmdeploy::net_module ALIAS ${PROJECT_NAME})
24 changes: 24 additions & 0 deletions csrc/mmdeploy/net/torchscript/CMakeLists.txt
@@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.

project(mmdeploy_torch_net)

find_package(Torch REQUIRED)
find_package(TorchVision)

mmdeploy_add_module(${PROJECT_NAME} torch_net.cpp)

target_link_libraries(${PROJECT_NAME} PRIVATE
${TORCH_LIBRARIES})

target_link_directories(${PROJECT_NAME} INTERFACE
$<BUILD_INTERFACE:${Torch_DIR}/../../../lib>)

target_link_libraries(${PROJECT_NAME} PRIVATE
mmdeploy_torchscript_ops_obj)

if (TorchVision_FOUND)
target_link_libraries(${PROJECT_NAME} PRIVATE TorchVision::TorchVision)
target_compile_definitions(${PROJECT_NAME} PRIVATE -DMMDEPLOY_USE_TORCHVISION=1)
endif ()

add_library(mmdeploy::torch_net ALIAS ${PROJECT_NAME})
237 changes: 237 additions & 0 deletions csrc/mmdeploy/net/torchscript/torch_net.cpp
@@ -0,0 +1,237 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "mmdeploy/net/torchscript/torch_net.h"

#include "mmdeploy/core/model.h"
#include "mmdeploy/core/utils/formatter.h"
#include "torch/torch.h"

#if MMDEPLOY_USE_CUDA
#include "c10/cuda/CUDAGuard.h"
#include "c10/cuda/CUDAStream.h"
#endif

#if MMDEPLOY_USE_TORCHVISION
#include "torchvision/vision.h"
MMDEPLOY_API void _mmdeploy_force_link_torchvision() { vision::detail::_register_ops(); }
#endif

namespace mmdeploy {

namespace {

class InferenceMode {
#if TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 10
c10::InferenceMode guard_;
#else
at::AutoNonVariableTypeMode guard_;
#endif
};

class StreamGuard {
public:
StreamGuard(const torch::Device& device, Stream stream)
: device_(device), stream_(std::move(stream)), device_guard_(device) {
stream_.Wait().value();
}

~StreamGuard() {
#if MMDEPLOY_USE_CUDA
auto device = stream_.GetDevice();
if (device.is_device()) {
Stream stream(device, (cudaStream_t)c10::cuda::getCurrentCUDAStream(device_.index()));
stream.Wait().value();
}
#endif
}

private:
torch::Device device_;
Stream stream_;
c10::DeviceGuard device_guard_;
};

Result<torch::ScalarType> FromDataType(DataType data_type) {
switch (data_type) {
case DataType::kFLOAT:
return torch::ScalarType::Float;
case DataType::kHALF:
return torch::ScalarType::Half;
case DataType::kINT32:
return torch::ScalarType::Int;
case DataType::kINT64:
return torch::ScalarType::Long;
case DataType::kINT8:
return torch::ScalarType::Char;
default:
MMDEPLOY_ERROR("Unsupported mmdeploy::DataType: {}", to_string(data_type));
return Status(eNotSupported);
}
}

Result<DataType> ToDataType(torch::ScalarType scalar_type) {
switch (scalar_type) {
case torch::ScalarType::Float:
return DataType::kFLOAT;
case torch::ScalarType::Half:
return DataType::kHALF;
case torch::ScalarType::Int:
return DataType::kINT32;
case torch::ScalarType::Long:
return DataType::kINT64;
case torch::ScalarType::Char:
return DataType::kINT8;
default:
MMDEPLOY_ERROR("Unsupported torch::ScalarType: {}", toString(scalar_type));
return Status(eNotSupported);
}
}

} // namespace

TorchNet::~TorchNet() = default;

Result<void> TorchNet::Init(const Value& cfg) {
auto& context = cfg["context"];
device_ = context["device"].get<Device>();
stream_ = context["stream"].get<Stream>();

auto name = cfg["name"].get<std::string>();
auto model = context["model"].get<Model>();

OUTCOME_TRY(auto config, model.GetModelConfig(name));
OUTCOME_TRY(auto bytes, model.ReadFile(config.net));

auto platform = Platform(device_.platform_id());
auto device_name = platform.GetPlatformName();

try {
{
using namespace std::string_literals;
if (device_name == "cpu"s) {
torch_device_ = torch::Device(device_name);
} else {
torch_device_ = torch::Device(device_name + ":"s + std::to_string(device_.device_id()));
}
}
std::istringstream iss(bytes);
InferenceMode guard;
module_ = torch::jit::load(iss);
module_.eval();
module_.to(*torch_device_);
auto forward = module_.get_method("forward");

auto ToDesc = [&](torch::jit::Value* value, const char* type, int index) {
MMDEPLOY_INFO("Found {}: {}", type, value->debugNameBase());
return TensorDesc{device_, DataType::kFLOAT, {}, "#" + std::to_string(index)};
};

auto inputs = forward.graph()->inputs();
int input_count = 0;
for (int i = 1; i < inputs.size(); ++i) {
if (inputs[i]->type()->kind() == c10::TypeKind::TensorType) {
input_tensor_.emplace_back(ToDesc(inputs[i], "input", input_count++));
} else {
MMDEPLOY_ERROR("Unsupported input type: {}", typeKindToString(inputs[i]->type()->kind()));
return Status(eNotSupported);
}
}

auto outputs = forward.graph()->outputs();
int output_count = 0;
for (const auto& output : outputs) {
auto kind = output->type()->kind();
if (kind == c10::TypeKind::TensorType) {
output_tensor_.emplace_back(ToDesc(output, "output", output_count++));
} else if (output->type()->kind() == c10::TypeKind::TupleType) {
for (const auto& v : output->node()->inputs()) {
if (v->type()->kind() == c10::TypeKind::TensorType) {
output_tensor_.emplace_back(ToDesc(v, "output", output_count++));
} else {
MMDEPLOY_ERROR("Unsupported output type: {}", typeKindToString(v->type()->kind()));
return Status(eNotSupported);
}
}
} else {
MMDEPLOY_ERROR("Unsupported output type: {}", typeKindToString(kind));
}
}
return success();
} catch (const std::exception& e) {
MMDEPLOY_ERROR("unhandled exception: {}", e.what());
return Status(eFail);
}
}

Result<void> TorchNet::Deinit() { return success(); }
Result<Span<Tensor>> TorchNet::GetInputTensors() { return input_tensor_; }
Result<Span<Tensor>> TorchNet::GetOutputTensors() { return output_tensor_; }

Result<void> TorchNet::Reshape(Span<TensorShape> input_shapes) {
if (input_shapes.size() != input_tensor_.size()) {
return Status(eInvalidArgument);
}
for (size_t i = 0; i < input_shapes.size(); ++i) {
input_tensor_[i].Reshape(input_shapes[i]);
}
return success();
}

Result<void> TorchNet::Forward() {
try {
StreamGuard stream_guard(*torch_device_, stream_);
InferenceMode inference_guard;
std::vector<torch::jit::IValue> inputs;
for (auto& v : input_tensor_) {
OUTCOME_TRY(auto data_type, FromDataType(v.data_type()));
auto tensor = torch::from_blob(v.data(), v.shape(),
c10::TensorOptions(*torch_device_).dtype(data_type));
inputs.emplace_back(tensor);
}
auto outputs = module_.forward(inputs);
if (outputs.isTensor()) {
OUTCOME_TRY(output_tensor_[0], FromTorchTensor(outputs.toTensor(), output_tensor_[0].name()));
} else if (outputs.isTuple()) {
auto tuple = outputs.toTuple();
size_t index = 0;
for (const auto& x : tuple->elements()) {
OUTCOME_TRY(output_tensor_[index],
FromTorchTensor(x.toTensor(), output_tensor_[index].name()));
++index;
}
} else {
MMDEPLOY_ERROR("{}", toString(outputs.type()));
return Status(eNotSupported);
}
} catch (const std::exception& e) {
MMDEPLOY_ERROR("unhandled exception: {}", e.what());
return Status(eFail);
}
return success();
}
Result<void> TorchNet::ForwardAsync(Event* event) { return success(); }

Result<Tensor> TorchNet::FromTorchTensor(const torch::Tensor& tensor, const std::string& name) {
OUTCOME_TRY(auto data_type, ToDataType(tensor.scalar_type()));
auto shape = tensor.sizes();
TensorDesc desc{device_, data_type, {shape.begin(), shape.end()}, name};
return Tensor(desc, std::shared_ptr<void>(tensor.data_ptr(), [tensor](auto) {}));
}

class TorchNetCreator : public Creator<Net> {
public:
const char* GetName() const override { return "torchscript"; }
std::unique_ptr<Net> Create(const Value& cfg) override {
auto p = std::make_unique<TorchNet>();
if (auto status = p->Init(cfg)) {
return p;
} else {
MMDEPLOY_ERROR("Failed to created TorchNet with config: {}", cfg);
}
return nullptr;
}
};

REGISTER_MODULE(Net, TorchNetCreator);

} // namespace mmdeploy
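
Forward() above hands each mmdeploy tensor to libtorch via torch::from_blob and returns results through FromTorchTensor, which aliases the torch tensor's storage rather than copying it. Below is a minimal standalone sketch of that zero-copy pattern using plain libtorch only; the buffer and the `input * 2` stand-in for `module_.forward(...)` are hypothetical.

```cpp
#include <torch/torch.h>

#include <iostream>
#include <memory>
#include <vector>

int main() {
  // Pretend this is the storage owned by an mmdeploy Tensor.
  std::vector<float> buffer(2 * 3, 1.0f);

  // Wrap the existing buffer as a 2x3 float tensor; from_blob does not copy.
  auto input = torch::from_blob(buffer.data(), {2, 3},
                                torch::TensorOptions().dtype(torch::kFloat));

  // Stand-in for module_.forward(...).
  auto output = input * 2;

  // Alias the output storage. Capturing `output` in the deleter keeps the
  // torch tensor (and hence its storage) alive as long as the alias lives,
  // mirroring the shared_ptr deleter in FromTorchTensor.
  std::shared_ptr<void> alias(output.data_ptr(), [output](void*) {});

  std::cout << output.sizes() << " " << toString(output.scalar_type()) << "\n";
  return 0;
}
```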
35 changes: 35 additions & 0 deletions csrc/mmdeploy/net/torchscript/torch_net.h
@@ -0,0 +1,35 @@
// Copyright (c) OpenMMLab. All rights reserved.

#ifndef MMDEPLOY_CSRC_MMDEPLOY_NET_TORCHSCRIPT_TORCH_NET_H_
#define MMDEPLOY_CSRC_MMDEPLOY_NET_TORCHSCRIPT_TORCH_NET_H_

#include "mmdeploy/core/net.h"
#include "torch/script.h"

namespace mmdeploy {

class TorchNet : public Net {
public:
~TorchNet() override;
Result<void> Init(const Value& cfg) override;
Result<void> Deinit() override;
Result<Span<Tensor>> GetInputTensors() override;
Result<Span<Tensor>> GetOutputTensors() override;
Result<void> Reshape(Span<TensorShape> input_shapes) override;
Result<void> Forward() override;
Result<void> ForwardAsync(Event* event) override;

private:
Result<Tensor> FromTorchTensor(const torch::Tensor& tensor, const std::string& name);

torch::jit::script::Module module_;
std::vector<Tensor> input_tensor_;
std::vector<Tensor> output_tensor_;
Device device_;
Stream stream_;
std::optional<torch::Device> torch_device_;
};

} // namespace mmdeploy

#endif // MMDEPLOY_CSRC_MMDEPLOY_NET_TORCHSCRIPT_TORCH_NET_H_
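
TorchNet::Init inspects the forward method's graph instead of relying on named bindings: input 0 is the module itself, the remaining TensorType inputs are exposed as "#0", "#1", ..., and a TupleType output is flattened into the inputs of its defining node. A minimal standalone sketch of that inspection with plain libtorch follows; the model filename is hypothetical.

```cpp
#include <torch/script.h>

#include <iostream>

int main() {
  // Hypothetical path to an exported TorchScript model.
  torch::jit::Module module = torch::jit::load("end2end.torchscript.pt");
  auto forward = module.get_method("forward");
  auto graph = forward.graph();

  // Input 0 is `self`; the rest are the user-visible inputs.
  auto inputs = graph->inputs();
  for (size_t i = 1; i < inputs.size(); ++i) {
    std::cout << "input  " << inputs[i]->debugNameBase() << " : "
              << c10::typeKindToString(inputs[i]->type()->kind()) << "\n";
  }

  // A tuple return value shows up as a single TupleType output whose defining
  // node's inputs are the individual tensors.
  for (auto* output : graph->outputs()) {
    if (output->type()->kind() == c10::TypeKind::TupleType) {
      for (auto* v : output->node()->inputs()) {
        std::cout << "output " << v->debugNameBase() << "\n";
      }
    } else {
      std::cout << "output " << output->debugNameBase() << "\n";
    }
  }
  return 0;
}
```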
17 changes: 13 additions & 4 deletions mmdeploy/backend/sdk/export_info.py
Expand Up @@ -167,16 +167,25 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
input = ['prep_output']
output = ['infer_output']
ir_config = get_ir_config(deploy_cfg)
input_names = ir_config.get('input_names', None)
input_name = input_names[0] if input_names else 'input'
input_map = dict(img=input_name)

backend = get_backend(deploy_cfg=deploy_cfg)
if backend == Backend.TORCHSCRIPT:
output_names = ir_config.get('output_names', None)
input_map = dict(img='#0')
output_map = {name: f'#{i}' for i, name in enumerate(output_names)}
else:
input_names = ir_config.get('input_names', None)
input_name = input_names[0] if input_names else 'input'
input_map = dict(img=input_name)
output_map = {}
return_dict = dict(
name=name,
type=type,
module=module,
input=input,
output=output,
input_map=input_map)
input_map=input_map,
output_map=output_map)
if 'use_vulkan' in deploy_cfg['backend_config']:
return_dict['use_vulkan'] = deploy_cfg['backend_config']['use_vulkan']
return return_dict