diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake
index ec16fbe9c4..5d105de9f5 100644
--- a/cmake/cuda.cmake
+++ b/cmake/cuda.cmake
@@ -16,19 +16,19 @@ find_package(CUDA REQUIRED)
 if (MSVC)
     set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc.exe)
-    set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler=/wd4819,/wd4828")
+    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/wd4819,/wd4828")
     if (HAVE_CXX_FLAG_UTF_8)
-        set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler=/utf-8")
+        set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/utf-8")
     endif ()
 else ()
     set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc)
     # Explicitly set the cuda host compiler. Because the default host compiler #
     # selected by cmake maybe wrong.
     set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER})
-    set(CUDA_NVCC_FLAGS
-        "${CUDA_NVCC_FLAGS} -Xcompiler=-fPIC,-Wall,-fvisibility=hidden")
+    set(CMAKE_CUDA_FLAGS
+        "${CMAKE_CUDA_FLAGS} -Xcompiler=-fPIC,-Wall,-fvisibility=hidden")
     if (CMAKE_CXX_COMPILER_ID MATCHES "GNU")
-        set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler=-fno-gnu-unique")
+        set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-gnu-unique")
     endif ()
 endif ()
@@ -62,10 +62,12 @@ if (NOT CMAKE_CUDA_ARCHITECTURES)
     endif ()
 endif ()
-set(CUDA_NVCC_FLAGS_DEBUG "-g -O0")
-set(CUDA_NVCC_FLAGS_RELEASE "-O3")
-set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS}")
+set(CMAKE_CUDA_FLAGS_DEBUG "-g -O0")
+set(CMAKE_CUDA_FLAGS_RELEASE "-O3")
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMMDEPLOY_USE_CUDA=1")
+
 if (NOT MSVC)
     set(CMAKE_CUDA_STANDARD 14)
 endif ()
-set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${CUDA_NVCC_FLAGS} ${_NVCC_FLAGS}")
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${_NVCC_FLAGS}")
diff --git a/csrc/mmdeploy/net/CMakeLists.txt b/csrc/mmdeploy/net/CMakeLists.txt
index 3b42740c27..09dca3cd72 100644
--- a/csrc/mmdeploy/net/CMakeLists.txt
+++ b/csrc/mmdeploy/net/CMakeLists.txt
@@ -26,5 +26,9 @@ if ("snpe" IN_LIST MMDEPLOY_TARGET_BACKENDS)
     add_subdirectory(snpe)
 endif ()
 
+if ("torchscript" IN_LIST MMDEPLOY_TARGET_BACKENDS)
+    add_subdirectory(torchscript)
+endif ()
+
 mmdeploy_add_module(${PROJECT_NAME} net_module.cpp)
 add_library(mmdeploy::net_module ALIAS ${PROJECT_NAME})
diff --git a/csrc/mmdeploy/net/torchscript/CMakeLists.txt b/csrc/mmdeploy/net/torchscript/CMakeLists.txt
new file mode 100644
index 0000000000..3d424ed474
--- /dev/null
+++ b/csrc/mmdeploy/net/torchscript/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+project(mmdeploy_torch_net)
+
+find_package(Torch REQUIRED)
+find_package(TorchVision)
+
+mmdeploy_add_module(${PROJECT_NAME} torch_net.cpp)
+
+target_link_libraries(${PROJECT_NAME} PRIVATE
+        ${TORCH_LIBRARIES})
+
+target_link_directories(${PROJECT_NAME} INTERFACE
+        $)
+
+target_link_libraries(${PROJECT_NAME} PRIVATE
+        mmdeploy_torchscript_ops_obj)
+
+if (TorchVision_FOUND)
+    target_link_libraries(${PROJECT_NAME} PRIVATE TorchVision::TorchVision)
+    target_compile_definitions(${PROJECT_NAME} PRIVATE -DMMDEPLOY_USE_TORCHVISION=1)
+endif ()
+
+add_library(mmdeploy::torch_net ALIAS ${PROJECT_NAME})
diff --git a/csrc/mmdeploy/net/torchscript/torch_net.cpp b/csrc/mmdeploy/net/torchscript/torch_net.cpp
new file mode 100644
index 0000000000..57c552048d
--- /dev/null
+++ b/csrc/mmdeploy/net/torchscript/torch_net.cpp
@@ -0,0 +1,237 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "mmdeploy/net/torchscript/torch_net.h"
+
+#include "mmdeploy/core/model.h"
+#include "mmdeploy/core/utils/formatter.h"
+#include "torch/torch.h"
+
+#if MMDEPLOY_USE_CUDA
+#include "c10/cuda/CUDAGuard.h"
+#include "c10/cuda/CUDAStream.h"
+#endif
+
+#if MMDEPLOY_USE_TORCHVISION
+#include "torchvision/vision.h"
+MMDEPLOY_API void _mmdeploy_force_link_torchvision() { vision::detail::_register_ops(); }
+#endif
+
+namespace mmdeploy {
+
+namespace {
+
+class InferenceMode {
+#if TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 10
+  c10::InferenceMode guard_;
+#else
+  at::AutoNonVariableTypeMode guard_;
+#endif
+};
+
+class StreamGuard {
+ public:
+  StreamGuard(const torch::Device& device, Stream stream)
+      : device_(device), stream_(std::move(stream)), device_guard_(device) {
+    stream_.Wait().value();
+  }
+
+  ~StreamGuard() {
+#if MMDEPLOY_USE_CUDA
+    auto device = stream_.GetDevice();
+    if (device.is_device()) {
+      Stream stream(device, (cudaStream_t)c10::cuda::getCurrentCUDAStream(device_.index()));
+      stream.Wait().value();
+    }
+#endif
+  }
+
+ private:
+  torch::Device device_;
+  Stream stream_;
+  c10::DeviceGuard device_guard_;
+};
+
+Result<torch::ScalarType> FromDataType(DataType data_type) {
+  switch (data_type) {
+    case DataType::kFLOAT:
+      return torch::ScalarType::Float;
+    case DataType::kHALF:
+      return torch::ScalarType::Half;
+    case DataType::kINT32:
+      return torch::ScalarType::Int;
+    case DataType::kINT64:
+      return torch::ScalarType::Long;
+    case DataType::kINT8:
+      return torch::ScalarType::Char;
+    default:
+      MMDEPLOY_ERROR("Unsupported mmdeploy::DataType: {}", to_string(data_type));
+      return Status(eNotSupported);
+  }
+}
+
+Result<DataType> ToDataType(torch::ScalarType scalar_type) {
+  switch (scalar_type) {
+    case torch::ScalarType::Float:
+      return DataType::kFLOAT;
+    case torch::ScalarType::Half:
+      return DataType::kHALF;
+    case torch::ScalarType::Int:
+      return DataType::kINT32;
+    case torch::ScalarType::Long:
+      return DataType::kINT64;
+    case torch::ScalarType::Char:
+      return DataType::kINT8;
+    default:
+      MMDEPLOY_ERROR("Unsupported torch::ScalarType: {}", toString(scalar_type));
+      return Status(eNotSupported);
+  }
+}
+
+}  // namespace
+
+TorchNet::~TorchNet() = default;
+
+Result<void> TorchNet::Init(const Value& cfg) {
+  auto& context = cfg["context"];
+  device_ = context["device"].get<Device>();
+  stream_ = context["stream"].get<Stream>();
+
+  auto name = cfg["name"].get<std::string>();
+  auto model = context["model"].get<Model>();
+
+  OUTCOME_TRY(auto config, model.GetModelConfig(name));
+  OUTCOME_TRY(auto bytes, model.ReadFile(config.net));
+
+  auto platform = Platform(device_.platform_id());
+  auto device_name = platform.GetPlatformName();
+
+  try {
+    {
+      using namespace std::string_literals;
+      if (device_name == "cpu"s) {
+        torch_device_ = torch::Device(device_name);
+      } else {
+        torch_device_ = torch::Device(device_name + ":"s + std::to_string(device_.device_id()));
+      }
+    }
+    std::istringstream iss(bytes);
+    InferenceMode guard;
+    module_ = torch::jit::load(iss);
+    module_.eval();
+    module_.to(*torch_device_);
+    auto forward = module_.get_method("forward");
+
+    auto ToDesc = [&](torch::jit::Value* value, const char* type, int index) {
+      MMDEPLOY_INFO("Found {}: {}", type, value->debugNameBase());
+      return TensorDesc{device_, DataType::kFLOAT, {}, "#" + std::to_string(index)};
+    };
+
+    auto inputs = forward.graph()->inputs();
+    int input_count = 0;
+    for (int i = 1; i < inputs.size(); ++i) {
+      if (inputs[i]->type()->kind() == c10::TypeKind::TensorType) {
+        input_tensor_.emplace_back(ToDesc(inputs[i], "input", input_count++));
+      } else {
+        MMDEPLOY_ERROR("Unsupported input type: {}", typeKindToString(inputs[i]->type()->kind()));
+        return Status(eNotSupported);
+      }
+    }
+
+    auto outputs = forward.graph()->outputs();
+    int output_count = 0;
+    for (const auto& output : outputs) {
+      auto kind = output->type()->kind();
+      if (kind == c10::TypeKind::TensorType) {
+        output_tensor_.emplace_back(ToDesc(output, "output", output_count++));
+      } else if (output->type()->kind() == c10::TypeKind::TupleType) {
+        for (const auto& v : output->node()->inputs()) {
+          if (v->type()->kind() == c10::TypeKind::TensorType) {
+            output_tensor_.emplace_back(ToDesc(v, "output", output_count++));
+          } else {
+            MMDEPLOY_ERROR("Unsupported output type: {}", typeKindToString(v->type()->kind()));
+            return Status(eNotSupported);
+          }
+        }
+      } else {
+        MMDEPLOY_ERROR("Unsupported output type: {}", typeKindToString(kind));
+      }
+    }
+    return success();
+  } catch (const std::exception& e) {
+    MMDEPLOY_ERROR("unhandled exception: {}", e.what());
+    return Status(eFail);
+  }
+}
+
+Result<void> TorchNet::Deinit() { return success(); }
+Result<Span<Tensor>> TorchNet::GetInputTensors() { return input_tensor_; }
+Result<Span<Tensor>> TorchNet::GetOutputTensors() { return output_tensor_; }
+
+Result<void> TorchNet::Reshape(Span<TensorShape> input_shapes) {
+  if (input_shapes.size() != input_tensor_.size()) {
+    return Status(eInvalidArgument);
+  }
+  for (size_t i = 0; i < input_shapes.size(); ++i) {
+    input_tensor_[i].Reshape(input_shapes[i]);
+  }
+  return success();
+}
+
+Result<void> TorchNet::Forward() {
+  try {
+    StreamGuard stream_guard(*torch_device_, stream_);
+    InferenceMode inference_guard;
+    std::vector<torch::jit::IValue> inputs;
+    for (auto& v : input_tensor_) {
+      OUTCOME_TRY(auto data_type, FromDataType(v.data_type()));
+      auto tensor = torch::from_blob(v.data(), v.shape(),
+                                     c10::TensorOptions(*torch_device_).dtype(data_type));
+      inputs.emplace_back(tensor);
+    }
+    auto outputs = module_.forward(inputs);
+    if (outputs.isTensor()) {
+      OUTCOME_TRY(output_tensor_[0], FromTorchTensor(outputs.toTensor(), output_tensor_[0].name()));
+    } else if (outputs.isTuple()) {
+      auto tuple = outputs.toTuple();
+      size_t index = 0;
+      for (const auto& x : tuple->elements()) {
+        OUTCOME_TRY(output_tensor_[index],
+                    FromTorchTensor(x.toTensor(), output_tensor_[index].name()));
+        ++index;
+      }
+    } else {
+      MMDEPLOY_ERROR("{}", toString(outputs.type()));
+      return Status(eNotSupported);
+    }
+  } catch (const std::exception& e) {
+    MMDEPLOY_ERROR("unhandled exception: {}", e.what());
+    return Status(eFail);
+  }
+  return success();
+}
+Result<void> TorchNet::ForwardAsync(Event* event) { return success(); }
+
+Result<Tensor> TorchNet::FromTorchTensor(const torch::Tensor& tensor, const std::string& name) {
+  OUTCOME_TRY(auto data_type, ToDataType(tensor.scalar_type()));
+  auto shape = tensor.sizes();
+  TensorDesc desc{device_, data_type, {shape.begin(), shape.end()}, name};
+  return Tensor(desc, std::shared_ptr<void>(tensor.data_ptr(), [tensor](auto) {}));
+}
+
+class TorchNetCreator : public Creator<Net> {
+ public:
+  const char* GetName() const override { return "torchscript"; }
+  std::unique_ptr<Net> Create(const Value& cfg) override {
+    auto p = std::make_unique<TorchNet>();
+    if (auto status = p->Init(cfg)) {
+      return p;
+    } else {
+      MMDEPLOY_ERROR("Failed to create TorchNet with config: {}", cfg);
+    }
+    return nullptr;
+  }
+};
+
+REGISTER_MODULE(Net, TorchNetCreator);
+
+}  // namespace mmdeploy
diff --git a/csrc/mmdeploy/net/torchscript/torch_net.h b/csrc/mmdeploy/net/torchscript/torch_net.h
new file mode 100644
index 0000000000..f7021470b6
--- /dev/null
+++ b/csrc/mmdeploy/net/torchscript/torch_net.h
@@ -0,0 +1,35 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#ifndef MMDEPLOY_CSRC_MMDEPLOY_NET_TORCHSCRIPT_TORCH_NET_H_
+#define MMDEPLOY_CSRC_MMDEPLOY_NET_TORCHSCRIPT_TORCH_NET_H_
+
+#include "mmdeploy/core/net.h"
+#include "torch/script.h"
+
+namespace mmdeploy {
+
+class TorchNet : public Net {
+ public:
+  ~TorchNet() override;
+  Result<void> Init(const Value& cfg) override;
+  Result<void> Deinit() override;
+  Result<Span<Tensor>> GetInputTensors() override;
+  Result<Span<Tensor>> GetOutputTensors() override;
+  Result<void> Reshape(Span<TensorShape> input_shapes) override;
+  Result<void> Forward() override;
+  Result<void> ForwardAsync(Event* event) override;
+
+ private:
+  Result<Tensor> FromTorchTensor(const torch::Tensor& tensor, const std::string& name);
+
+  torch::jit::script::Module module_;
+  std::vector<Tensor> input_tensor_;
+  std::vector<Tensor> output_tensor_;
+  Device device_;
+  Stream stream_;
+  std::optional<torch::Device> torch_device_;
+};
+
+}  // namespace mmdeploy
+
+#endif  // MMDEPLOY_CSRC_MMDEPLOY_NET_TORCHSCRIPT_TORCH_NET_H_
diff --git a/mmdeploy/backend/sdk/export_info.py b/mmdeploy/backend/sdk/export_info.py
index 8e68f413f2..aa5310ee97 100644
--- a/mmdeploy/backend/sdk/export_info.py
+++ b/mmdeploy/backend/sdk/export_info.py
@@ -167,16 +167,25 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
     input = ['prep_output']
     output = ['infer_output']
     ir_config = get_ir_config(deploy_cfg)
-    input_names = ir_config.get('input_names', None)
-    input_name = input_names[0] if input_names else 'input'
-    input_map = dict(img=input_name)
+
+    backend = get_backend(deploy_cfg=deploy_cfg)
+    if backend == Backend.TORCHSCRIPT:
+        output_names = ir_config.get('output_names', None)
+        input_map = dict(img='#0')
+        output_map = {name: f'#{i}' for i, name in enumerate(output_names)}
+    else:
+        input_names = ir_config.get('input_names', None)
+        input_name = input_names[0] if input_names else 'input'
+        input_map = dict(img=input_name)
+        output_map = {}
     return_dict = dict(
         name=name,
         type=type,
         module=module,
         input=input,
         output=output,
-        input_map=input_map)
+        input_map=input_map,
+        output_map=output_map)
     if 'use_vulkan' in deploy_cfg['backend_config']:
         return_dict['use_vulkan'] = deploy_cfg['backend_config']['use_vulkan']
     return return_dict
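
For orientation only (not part of the patch), the following is a minimal Python sketch of the load/forward flow that `torch_net.cpp` implements through libtorch: deserialize the scripted module from memory, switch it to eval mode, move it to the target device, run it under inference mode, and flatten either a single tensor or a tuple of tensors into an ordered output list. Function name, arguments, and shapes are illustrative assumptions.

```python
import io
import torch


def run_torchscript(model_bytes: bytes, inputs, device: str = "cpu"):
    # Mirrors TorchNet::Init: load the serialized module from a memory buffer,
    # put it in eval mode and move it to the requested device.
    module = torch.jit.load(io.BytesIO(model_bytes), map_location=device)
    module.eval()

    # Mirrors the InferenceMode guard used around load and forward.
    with torch.inference_mode():
        outputs = module(*[t.to(device) for t in inputs])

    # Mirrors TorchNet::Forward: a lone tensor or a tuple of tensors is
    # flattened into the ordered outputs "#0", "#1", ...
    if isinstance(outputs, torch.Tensor):
        return [outputs]
    return list(outputs)
```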
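Likewise, a small sketch of what the new `Backend.TORCHSCRIPT` branch in `get_inference_info` produces; the `output_names` list is a made-up example, and only the `'#<index>'` convention comes from the patch (it matches the `"#" + std::to_string(index)` tensor names assigned by the `ToDesc` lambda in `torch_net.cpp`).

```python
# Hypothetical ir_config contents, for illustration only.
output_names = ['dets', 'labels']

input_map = dict(img='#0')
output_map = {name: f'#{i}' for i, name in enumerate(output_names)}

print(input_map)   # {'img': '#0'}
print(output_map)  # {'dets': '#0', 'labels': '#1'}
```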