From baa565975a6707567e0d260afeb89bb2f5f1f26f Mon Sep 17 00:00:00 2001
From: Anurag Dixit
Date: Mon, 1 Aug 2022 16:37:23 -0700
Subject: [PATCH 1/3] feat(cpp): Added support for loading runtime custom torch op and custom converters in torchtrtc

Signed-off-by: Anurag Dixit
---
 cpp/bin/torchtrtc/BUILD          |  3 ++
 cpp/bin/torchtrtc/CMakeLists.txt |  2 +-
 cpp/bin/torchtrtc/main.cpp       | 79 +++++++++++++++++++++++++++++++-
 3 files changed, 81 insertions(+), 3 deletions(-)

diff --git a/cpp/bin/torchtrtc/BUILD b/cpp/bin/torchtrtc/BUILD
index 9265948b97..9d58e3211b 100644
--- a/cpp/bin/torchtrtc/BUILD
+++ b/cpp/bin/torchtrtc/BUILD
@@ -19,6 +19,9 @@ cc_binary(
         "parser_util.h",
         "parser_util.cpp"
     ],
+    linkopts = [
+        "-l:libdl.so"
+    ],
     deps = [
         "//third_party/args",
         "//cpp:torch_tensorrt",
diff --git a/cpp/bin/torchtrtc/CMakeLists.txt b/cpp/bin/torchtrtc/CMakeLists.txt
index 0ebfd87609..b12461e12a 100644
--- a/cpp/bin/torchtrtc/CMakeLists.txt
+++ b/cpp/bin/torchtrtc/CMakeLists.txt
@@ -10,7 +10,7 @@ add_executable(${executable_name}
 if (MSVC)
     target_link_libraries(${executable_name} PRIVATE torch torchtrt)
 else()
-    target_link_libraries(${executable_name} PRIVATE torch "-Wl,--no-as-needed" torchtrt "-Wl,--as-needed")
+    target_link_libraries(${executable_name} PRIVATE torch "-Wl,--no-as-needed -ldl" torchtrt "-Wl,--as-needed")
     set_target_properties(
         ${executable_name}
         PROPERTIES INSTALL_RPATH_USE_LINK_PATH FALSE #
diff --git a/cpp/bin/torchtrtc/main.cpp b/cpp/bin/torchtrtc/main.cpp
index 6c207d78da..b8ec67024a 100644
--- a/cpp/bin/torchtrtc/main.cpp
+++ b/cpp/bin/torchtrtc/main.cpp
@@ -15,6 +15,47 @@
 #include "luts.h"
 #include "parser_util.h"

+#if defined(_WIN32)
+#include <windows.h>
+#else
+#include <dlfcn.h>
+#endif
+
+void load_library(std::string& plugin, std::string option, void* handle) {
+#if defined(_WIN32)
+  handle = LoadLibrary(plugin.c_str());
+#else
+  handle = dlopen(plugin.c_str(), RTLD_LAZY);
+#endif
+  if (handle == nullptr) {
+    torchtrt::logging::log(
+        torchtrt::logging::Level::kERROR, std::string("Could not load custom library " + plugin + " for " + option));
+  } else {
+    torchtrt::logging::log(
+        torchtrt::logging::Level::kINFO, std::string("Loaded custom library " + plugin + " for " + option));
+  }
+}
+
+void unload_library(void* custom_lib, std::string& name) {
+#if defined(_WIN32)
+  auto status = FreeLibrary(custom_lib);
+  // Return status non-zero for success
+  if (status) {
+    torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + name));
+  } else {
+    torchtrt::logging::log(torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + name));
+  }
+#else
+  auto status = dlclose(custom_lib);
+  // Return status 0 for success
+  if (!status) {
+    torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + name));
+  } else {
+    torchtrt::logging::log(torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + name));
+  }
+#endif
+}
+
 int main(int argc, char** argv) {
   torchtrt::logging::set_is_colored_output_on(true);
   torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kWARNING);
@@ -117,8 +158,7 @@ int main(int argc, char** argv) {
       parser, "num_iters", "Number of averaging timing iterations used to select kernels", {"num-avg-timing-iters"});
   args::ValueFlag<uint64_t> workspace_size(
       parser, "workspace_size", "Maximum size of workspace given to TensorRT", {"workspace-size"});
-  args::ValueFlag<uint64_t> dla_sram_size(
-      parser, "dla_sram_size", "DLA managed SRAM size", {"dla-sram-size"});
+  args::ValueFlag<uint64_t> dla_sram_size(parser, "dla_sram_size", "DLA managed SRAM size", {"dla-sram-size"});
   args::ValueFlag<uint64_t> dla_local_dram_size(
       parser, "dla_local_dram_size", "DLA Local DRAM size", {"dla-local-dram-size"});
   args::ValueFlag<uint64_t> dla_global_dram_size(
@@ -147,6 +187,12 @@ int main(int argc, char** argv) {
       "save_engine",
       "Instead of compiling a full a TorchScript program, save the created engine to the path specified as the output path",
       {"save-engine"});
+  args::ValueFlagList<std::string> custom_torch_ops(
+      parser, "custom-torch-ops", "Shared object/DLL containing custom torch operator", {"custom-torch-ops"});
+
+  args::ValueFlagList<std::string> custom_converters(
+      parser, "custom-converters", "Shared object/DLL containing custom converters", {"custom-converters"});
+
   args::Positional<std::string> input_path(parser, "input_file_path", "Path to input TorchScript file");
   args::Positional<std::string> output_path(
       parser, "output_file_path", "Path for compiled TorchScript (or TensorRT engine) file");
@@ -174,6 +220,23 @@ int main(int argc, char** argv) {
     torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kERROR);
   }

+  std::vector<std::pair<std::string, void*>> custom_torch_op, custom_converter_op;
+  if (custom_torch_ops) {
+    for (auto& op : args::get(custom_torch_ops)) {
+      void* handle{nullptr};
+      load_library(op, "custom_torch_ops", handle);
+      custom_torch_op.push_back({op, handle});
+    }
+  }
+
+  if (custom_converters) {
+    for (auto& op : args::get(custom_converters)) {
+      void* handle{nullptr};
+      load_library(op, "custom_converters", handle);
+      custom_converter_op.push_back({op, handle});
+    }
+  }
+
   auto real_input_path = torchtrtc::fileio::resolve_path(args::get(input_path));

   if (check_method_op_support) {
@@ -477,5 +540,17 @@ int main(int argc, char** argv) {
     trt_mod.save(real_output_path);
   }

+  if (custom_torch_ops) {
+    for (auto& p : custom_torch_op) {
+      unload_library(p.second, p.first);
+    }
+  }
+
+  if (custom_converters) {
+    for (auto& p : custom_converter_op) {
+      unload_library(p.second, p.first);
+    }
+  }
+
   return 0;
 }

From 79c82595e6b3eaaeaa688df2889fa3cc5c6f48f8 Mon Sep 17 00:00:00 2001
From: Anurag Dixit
Date: Mon, 1 Aug 2022 20:27:59 -0700
Subject: [PATCH 2/3] feat(//cpp): Fixed the failure for custom library loading

Signed-off-by: Anurag Dixit
---
 cpp/bin/torchtrtc/main.cpp | 87 ++++++++++++++++++++++----------------
 1 file changed, 51 insertions(+), 36 deletions(-)

diff --git a/cpp/bin/torchtrtc/main.cpp b/cpp/bin/torchtrtc/main.cpp
index b8ec67024a..1a004b7133 100644
--- a/cpp/bin/torchtrtc/main.cpp
+++ b/cpp/bin/torchtrtc/main.cpp
@@ -21,39 +21,25 @@
 #include <dlfcn.h>
 #endif

-void load_library(std::string& plugin, std::string option, void* handle) {
+void* load_library(std::string& custom_lib) {
+  void* handle = {nullptr};
 #if defined(_WIN32)
-  handle = LoadLibrary(plugin.c_str());
+  handle = LoadLibrary(custom_lib.c_str());
 #else
-  handle = dlopen(plugin.c_str(), RTLD_LAZY);
+  handle = dlopen(custom_lib.c_str(), RTLD_LAZY);
 #endif
-  if (handle == nullptr) {
-    torchtrt::logging::log(
-        torchtrt::logging::Level::kERROR, std::string("Could not load custom library " + plugin + " for " + option));
-  } else {
-    torchtrt::logging::log(
-        torchtrt::logging::Level::kINFO, std::string("Loaded custom library " + plugin + " for " + option));
-  }
+  return handle;
 }

-void unload_library(void* custom_lib, std::string& name) {
+bool unload_library(void* custom_lib) {
+  bool success = false;
 #if defined(_WIN32)
-  auto status = FreeLibrary(custom_lib);
-  // Return status non-zero for success
-  if (status) {
-    torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + name));
-  } else {
-    torchtrt::logging::log(torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + name));
-  }
+  // Returns status non-zero for success
+  success = FreeLibrary(custom_lib) ? true : false;
 #else
-  auto status = dlclose(custom_lib);
-  // Return status 0 for success
-  if (!status) {
-    torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + name));
-  } else {
-    torchtrt::logging::log(torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + name));
-  }
+  success = dlclose(custom_lib) ? false : true;
 #endif
+  return success;
 }

 int main(int argc, char** argv) {
   torchtrt::logging::set_is_colored_output_on(true);
   torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kWARNING);
@@ -188,10 +174,16 @@ int main(int argc, char** argv) {
       "Instead of compiling a full a TorchScript program, save the created engine to the path specified as the output path",
       {"save-engine"});
   args::ValueFlagList<std::string> custom_torch_ops(
-      parser, "custom-torch-ops", "Shared object/DLL containing custom torch operator", {"custom-torch-ops"});
+      parser,
+      "custom-torch-ops",
+      "(repeatable) Shared object/DLL containing custom torch operator",
+      {"custom-torch-ops"});

   args::ValueFlagList<std::string> custom_converters(
-      parser, "custom-converters", "Shared object/DLL containing custom converters", {"custom-converters"});
+      parser,
+      "custom-converters",
+      "(repeatable) Shared object/DLL containing custom converters",
+      {"custom-converters"});

   args::Positional<std::string> input_path(parser, "input_file_path", "Path to input TorchScript file");
   args::Positional<std::string> output_path(
       parser, "output_file_path", "Path for compiled TorchScript (or TensorRT engine) file");
@@ -223,17 +215,28 @@ int main(int argc, char** argv) {
   std::vector<std::pair<std::string, void*>> custom_torch_op, custom_converter_op;
   if (custom_torch_ops) {
     for (auto& op : args::get(custom_torch_ops)) {
-      void* handle{nullptr};
-      load_library(op, "custom_torch_ops", handle);
-      custom_torch_op.push_back({op, handle});
+      auto* handle = load_library(op);
+      if (handle == nullptr) {
+        torchtrt::logging::log(
+            torchtrt::logging::Level::kERROR, std::string("Could not load custom_torch_ops library " + op));
+      } else {
+        torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Loaded custom_torch_ops library " + op));
+
+        custom_torch_op.push_back({op, handle});
+      }
     }
   }

   if (custom_converters) {
     for (auto& op : args::get(custom_converters)) {
-      void* handle{nullptr};
-      load_library(op, "custom_converters", handle);
-      custom_converter_op.push_back({op, handle});
+      auto* handle = load_library(op);
+      if (handle == nullptr) {
+        torchtrt::logging::log(
+            torchtrt::logging::Level::kERROR, std::string("Could not load custom_converter library " + op));
+      } else {
+        torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Loaded custom_converter library " + op));
+        custom_converter_op.push_back({op, handle});
+      }
     }
   }

@@ -252,7 +255,7 @@ int main(int argc, char** argv) {
     auto method = args::get(check_method_op_support);
     auto result = torchtrt::ts::check_method_operator_support(mod, method);
     if (result) {
-      std::cout << "The method is supported end to end by Torch-TensorRT" << std::endl;
+      torchtrt::logging::log(torchtrt::logging::Level::kINFO, "The method is supported end to end by Torch-TensorRT");
       return 0;
     } else {
       torchtrt::logging::log(torchtrt::logging::Level::kERROR, "Method is not currently supported by Torch-TensorRT");
@@ -542,13 +545,25 @@ int main(int argc, char** argv) {

   if (custom_torch_ops) {
     for (auto& p : custom_torch_op) {
-      unload_library(p.second, p.first);
+      auto status = unload_library(p.second);
+      if (status) {
+        torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + p.first));
+      } else {
+        torchtrt::logging::log(
+            torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + p.first));
+      }
     }
   }

   if (custom_converters) {
     for (auto& p : custom_converter_op) {
-      unload_library(p.second, p.first);
+      auto status = unload_library(p.second);
+      if (status) {
+        torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + p.first));
+      } else {
+        torchtrt::logging::log(
+            torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + p.first));
+      }
     }
   }

From 74f447500e62c0da4efd59c326a675c6291a34ee Mon Sep 17 00:00:00 2001
From: Anurag Dixit
Date: Fri, 5 Aug 2022 19:41:51 -0700
Subject: [PATCH 3/3] chore: Review comments incorporated

Signed-off-by: Anurag Dixit
---
 cpp/bin/torchtrtc/README.md    | 13 +++++++++++++
 cpp/bin/torchtrtc/main.cpp     |  2 +-
 docsrc/tutorials/torchtrtc.rst | 12 ++++++++++++
 3 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/cpp/bin/torchtrtc/README.md b/cpp/bin/torchtrtc/README.md
index 498f25ea17..6466ada390 100644
--- a/cpp/bin/torchtrtc/README.md
+++ b/cpp/bin/torchtrtc/README.md
@@ -108,6 +108,8 @@ torchtrtc [input_file_path] [output_file_path]
                                             TorchScript program, save the created
                                             engine to the path specified as the
                                             output path
+      --custom-torch-ops=[lib]              (repeatable) Shared object/DLL containing custom torch operators
+      --custom-converters=[lib]             (repeatable) Shared object/DLL containing custom converters
       input_file_path                       Path to input TorchScript file
       output_file_path                      Path for compiled TorchScript
                                             (or TensorRT engine) file
@@ -131,3 +133,14 @@ e.g.
 ```
 torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
 ```
+
+
+To run with custom torch operators
+```
+torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-torch-ops=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
+```
+
+To run with custom converters
+```
+torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-converters=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
+```
\ No newline at end of file
diff --git a/cpp/bin/torchtrtc/main.cpp b/cpp/bin/torchtrtc/main.cpp
index 1a004b7133..f98ed848de 100644
--- a/cpp/bin/torchtrtc/main.cpp
+++ b/cpp/bin/torchtrtc/main.cpp
@@ -176,7 +176,7 @@ int main(int argc, char** argv) {
   args::ValueFlagList<std::string> custom_torch_ops(
       parser,
       "custom-torch-ops",
-      "(repeatable) Shared object/DLL containing custom torch operator",
+      "(repeatable) Shared object/DLL containing custom torch operators",
       {"custom-torch-ops"});

   args::ValueFlagList<std::string> custom_converters(
diff --git a/docsrc/tutorials/torchtrtc.rst b/docsrc/tutorials/torchtrtc.rst
index 5a2808bb8d..68f599a5cd 100644
--- a/docsrc/tutorials/torchtrtc.rst
+++ b/docsrc/tutorials/torchtrtc.rst
@@ -111,6 +111,8 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
                                             TorchScript program, save the created
                                             engine to the path specified as the
                                             output path
+      --custom-torch-ops                    (repeatable) Shared object/DLL containing custom torch operators
+      --custom-converters                   (repeatable) Shared object/DLL containing custom converters
       input_file_path                       Path to input TorchScript file
       output_file_path                      Path for compiled TorchScript
                                             (or TensorRT engine) file
@@ -132,3 +134,13 @@ e.g.
 .. code-block:: shell

     torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@f16%contiguous" -p f16
+
+To run with custom torch operators
+
+.. code-block:: shell
+    torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-torch-ops=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
+
+To run with custom converters
+
+.. code-block:: shell
+    torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-converters=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
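
For context on what a `--custom-torch-ops` library contains, below is a minimal sketch of a shared object that could be passed to the new flag. It is not part of this patch series: the `my_ops` namespace, the `custom_relu` operator, and the file name are illustrative assumptions; the only API relied on is PyTorch's `TORCH_LIBRARY` registration macro, whose static registration runs when torchtrtc `dlopen`s (or `LoadLibrary`s) the library.

```cpp
// custom_relu_ops.cpp - hypothetical example, not part of this patch series.
#include <torch/script.h>

// A trivial custom operator: clamps negative values to zero.
torch::Tensor custom_relu(torch::Tensor input) {
  return torch::relu(input);
}

// Registration happens at library load time, i.e. when torchtrtc opens the
// shared object passed via --custom-torch-ops, making my_ops::custom_relu
// visible to the TorchScript runtime.
TORCH_LIBRARY(my_ops, m) {
  m.def("custom_relu", custom_relu);
}
```

Built as a shared object (for example `g++ -shared -fPIC custom_relu_ops.cpp -o libcustom_relu_ops.so`, plus the usual libtorch include and link flags for your install), it would be loaded with `torchtrtc ... --custom-torch-ops=libcustom_relu_ops.so ...` as in the examples above; `--custom-converters` libraries are loaded the same way.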