diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py index cc894fdc4f..8bfe9dc355 100644 --- a/py/trtorch/_compile_spec.py +++ b/py/trtorch/_compile_spec.py @@ -259,37 +259,40 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt. backend_spec = torch.classes.tensorrt.CompileSpec() for i in parsed_spec.input_ranges: - ir = torch.classes.tensorrt.InputRange() - ir.set_min(i.min) - ir.set_opt(i.opt) - ir.set_max(i.max) - backend_spec.append_input_range(ir) - - d = torch.classes.tensorrt.Device() - d.set_device_type(int(parsed_spec.device.device_type)) - d.set_gpu_id(parsed_spec.device.gpu_id) - d.set_dla_core(parsed_spec.device.dla_core) - d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback) - - torch_fallback = torch.classes.tensorrt.TorchFallback() - torch_fallback.set_enabled(parsed_spec.torch_fallback.enabled) - torch_fallback.set_min_block_size(parsed_spec.torch_fallback.min_block_size) - torch_fallback.set_forced_fallback_operators(parsed_spec.torch_fallback.forced_fallback_operators) - - backend_spec.set_device(d) - backend_spec.set_torch_fallback(fallback) - backend_spec.set_op_precision(int(parsed_spec.op_precision)) - backend_spec.set_disable_tf32(parsed_spec.disable_tf32) - backend_spec.set_refit(parsed_spec.refit) - backend_spec.set_debug(parsed_spec.debug) - backend_spec.set_refit(parsed_spec.refit) - backend_spec.set_strict_types(parsed_spec.strict_types) - backend_spec.set_capability(int(parsed_spec.capability)) - backend_spec.set_num_min_timing_iters(parsed_spec.num_min_timing_iters) - backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters) - backend_spec.set_workspace_size(parsed_spec.workspace_size) - backend_spec.set_max_batch_size(parsed_spec.max_batch_size) - backend_spec.set_truncate_long_and_double(parsed_spec.truncate_long_and_double) + ir = torch.classes.tensorrt._InputRange() + ir._set_min(i.min) + ir._set_opt(i.opt) + ir._set_max(i.max) + backend_spec._append_input_range(ir) + + d = torch.classes.tensorrt._Device() + d._set_device_type(int(parsed_spec.device.device_type)) + d._set_gpu_id(parsed_spec.device.gpu_id) + d._set_dla_core(parsed_spec.device.dla_core) + d._set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback) + + if parsed_spec.torch_fallback.enabled: + raise RuntimeError("Partial module compilation is not currently supported via the PyTorch to_backend API integration. If you need partial compilation, use trtorch.compile") + + torch_fallback = torch.classes.tensorrt._TorchFallback() + torch_fallback._set_enabled(parsed_spec.torch_fallback.enabled) + torch_fallback._set_min_block_size(parsed_spec.torch_fallback.min_block_size) + torch_fallback._set_forced_fallback_operators(parsed_spec.torch_fallback.forced_fallback_operators) + + backend_spec._set_device(d) + backend_spec._set_torch_fallback(torch_fallback) + backend_spec._set_op_precision(int(parsed_spec.op_precision)) + backend_spec._set_disable_tf32(parsed_spec.disable_tf32) + backend_spec._set_refit(parsed_spec.refit) + backend_spec._set_debug(parsed_spec.debug) + backend_spec._set_refit(parsed_spec.refit) + backend_spec._set_strict_types(parsed_spec.strict_types) + backend_spec._set_capability(int(parsed_spec.capability)) + backend_spec._set_num_min_timing_iters(parsed_spec.num_min_timing_iters) + backend_spec._set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters) + backend_spec._set_workspace_size(parsed_spec.workspace_size) + backend_spec._set_max_batch_size(parsed_spec.max_batch_size) + backend_spec._set_truncate_long_and_double(parsed_spec.truncate_long_and_double) backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle()) return backend_spec diff --git a/py/trtorch/csrc/register_tensorrt_classes.cpp b/py/trtorch/csrc/register_tensorrt_classes.cpp index 2df85d2805..db5e522750 100644 --- a/py/trtorch/csrc/register_tensorrt_classes.cpp +++ b/py/trtorch/csrc/register_tensorrt_classes.cpp @@ -5,19 +5,23 @@ namespace backend { namespace { #define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \ - (registry).def("set_" #field_name, &class_name::set_##field_name); \ - (registry).def("get_" #field_name, &class_name::get_##field_name); + (registry).def("_set_" #field_name, &class_name::set_##field_name); \ + (registry).def("_get_" #field_name, &class_name::get_##field_name); void RegisterTRTCompileSpec() { static auto TRTORCH_UNUSED TRTInputRangeTSRegistration = - torch::class_("tensorrt", "InputRange").def(torch::init<>()); + torch::class_("tensorrt", "_InputRange") + .def(torch::init<>()) + .def("__str__", &trtorch::pyapi::InputRange::to_str); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max); static auto TRTORCH_UNUSED TRTDeviceTSRegistration = - torch::class_("tensorrt", "Device").def(torch::init<>()); + torch::class_("tensorrt", "_Device") + .def(torch::init<>()) + .def("__str__", &trtorch::pyapi::Device::to_str); ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type); ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id); @@ -25,7 +29,10 @@ void RegisterTRTCompileSpec() { ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback); static auto TRTORCH_UNUSED TRTFallbackTSRegistration = - torch::class_("tensorrt", "Fallback").def(torch::init<>()); + torch::class_("tensorrt", "_TorchFallback") + .def(torch::init<>()) + .def("__str__", &trtorch::pyapi::TorchFallback::to_str); + ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled); ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size); ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_operators); @@ -33,9 +40,9 @@ void RegisterTRTCompileSpec() { static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration = torch::class_("tensorrt", "CompileSpec") .def(torch::init<>()) - .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange) - .def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive) - .def("set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive) + .def("_append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange) + .def("_set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive) + .def("_set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive) .def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle) .def("__str__", &trtorch::pyapi::CompileSpec::stringify); diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp index 771e516616..de0955c751 100644 --- a/py/trtorch/csrc/tensorrt_classes.cpp +++ b/py/trtorch/csrc/tensorrt_classes.cpp @@ -4,7 +4,7 @@ namespace trtorch { namespace pyapi { -std::string to_str(InputRange& value) { +std::string InputRange::to_str() { auto vec_to_str = [](std::vector shape) -> std::string { std::stringstream ss; ss << '['; @@ -17,9 +17,9 @@ std::string to_str(InputRange& value) { std::stringstream ss; ss << " {" << std::endl; - ss << " min: " << vec_to_str(value.min) << ',' << std::endl; - ss << " opt: " << vec_to_str(value.opt) << ',' << std::endl; - ss << " max: " << vec_to_str(value.max) << ',' << std::endl; + ss << " min: " << vec_to_str(min) << ',' << std::endl; + ss << " opt: " << vec_to_str(opt) << ',' << std::endl; + ss << " max: " << vec_to_str(max) << ',' << std::endl; ss << " }" << std::endl; return ss.str(); } @@ -68,6 +68,18 @@ nvinfer1::DeviceType toTRTDeviceType(DeviceType value) { } } +std::string Device::to_str() { + std::stringstream ss; + std::string fallback = allow_gpu_fallback ? "True" : "False"; + ss << " {" << std::endl; + ss << " \"device_type\": " << pyapi::to_str(device_type) << std::endl; + ss << " \"allow_gpu_fallback\": " << fallback << std::endl; + ss << " \"gpu_id\": " << gpu_id << std::endl; + ss << " \"dla_core\": " << dla_core << std::endl; + ss << " }" << std::endl; + return ss.str(); +} + std::string to_str(EngineCapability value) { switch (value) { case EngineCapability::kSAFE_GPU: @@ -92,6 +104,21 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) { } } +std::string TorchFallback::to_str() { + std::stringstream ss; + std::string e = enabled ? "True" : "False"; + ss << " {" << std::endl; + ss << " \"enabled\": " << e << std::endl; + ss << " \"min_block_size\": " << min_block_size << std::endl; + ss << " \"forced_fallback_operators\": [" << std::endl; + for (auto i : forced_fallback_operators) { + ss << " " << i << ',' << std::endl; + } + ss << " ]" << std::endl; + ss << " }" << std::endl; + return ss.str(); +} + core::CompileSpec CompileSpec::toInternalCompileSpec() { std::vector internal_input_ranges; for (auto i : input_ranges) { @@ -128,36 +155,25 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() { std::string CompileSpec::stringify() { std::stringstream ss; ss << "TensorRT Compile Spec: {" << std::endl; - ss << " \"Input Shapes\": [" << std::endl; + ss << " \"Input Shapes\": [" << std::endl; for (auto i : input_ranges) { - ss << to_str(i); + ss << i.to_str(); } std::string enabled = torch_fallback.enabled ? "True" : "False"; - ss << " ]" << std::endl; - ss << " \"Op Precision\": " << to_str(op_precision) << std::endl; - ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl; - ss << " \"Refit\": " << refit << std::endl; - ss << " \"Debug\": " << debug << std::endl; - ss << " \"Strict Types\": " << strict_types << std::endl; - ss << " \"Device Type: " << to_str(device.device_type) << std::endl; - ss << " \"GPU ID: " << device.gpu_id << std::endl; - ss << " \"DLA Core: " << device.dla_core << std::endl; - ss << " \"Allow GPU Fallback\": " << device.allow_gpu_fallback << std::endl; - ss << " \"Engine Capability\": " << to_str(capability) << std::endl; - ss << " \"Num Min Timing Iters\": " << num_min_timing_iters << std::endl; - ss << " \"Num Avg Timing Iters\": " << num_avg_timing_iters << std::endl; - ss << " \"Workspace Size\": " << workspace_size << std::endl; - ss << " \"Max Batch Size\": " << max_batch_size << std::endl; - ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl; - ss << " \"Torch Fallback: {" << std::endl; - ss << " \"enabled\": " << enabled << std::endl; - ss << " \"min_block_size\": " << torch_fallback.min_block_size << std::endl; - ss << " \"forced_fallback_operators\": [" << std::endl; - for (auto i : torch_fallback.forced_fallback_operators) { - ss << " " << i << ',' << std::endl; - } - ss << " ]" << std::endl; - ss << " }" << std::endl; + ss << " ]" << std::endl; + ss << " \"Op Precision\": " << to_str(op_precision) << std::endl; + ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl; + ss << " \"Refit\": " << refit << std::endl; + ss << " \"Debug\": " << debug << std::endl; + ss << " \"Strict Types\": " << strict_types << std::endl; + ss << " \"Device\": " << device.to_str() << std::endl; + ss << " \"Engine Capability\": " << to_str(capability) << std::endl; + ss << " \"Num Min Timing Iters\": " << num_min_timing_iters << std::endl; + ss << " \"Num Avg Timing Iters\": " << num_avg_timing_iters << std::endl; + ss << " \"Workspace Size\": " << workspace_size << std::endl; + ss << " \"Max Batch Size\": " << max_batch_size << std::endl; + ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl; + ss << " \"Torch Fallback\": " << torch_fallback.to_str(); ss << "}"; return ss.str(); } diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h index 024e73d06a..4dd7a8f9d4 100644 --- a/py/trtorch/csrc/tensorrt_classes.h +++ b/py/trtorch/csrc/tensorrt_classes.h @@ -39,9 +39,9 @@ struct InputRange : torch::CustomClassHolder { ADD_FIELD_GET_SET(min, std::vector); ADD_FIELD_GET_SET(opt, std::vector); ADD_FIELD_GET_SET(max, std::vector); -}; -std::string to_str(InputRange& value); + std::string to_str(); +}; enum class DataType : int8_t { kFloat, @@ -73,6 +73,8 @@ struct Device : torch::CustomClassHolder { ADD_FIELD_GET_SET(gpu_id, int64_t); ADD_FIELD_GET_SET(dla_core, int64_t); ADD_FIELD_GET_SET(allow_gpu_fallback, bool); + + std::string to_str(); }; std::string to_str(DeviceType value); @@ -87,8 +89,11 @@ struct TorchFallback : torch::CustomClassHolder { ADD_FIELD_GET_SET(enabled, bool); ADD_FIELD_GET_SET(min_block_size, int64_t); ADD_FIELD_GET_SET(forced_fallback_operators, std::vector); + + std::string to_str(); }; + enum class EngineCapability : int8_t { kDEFAULT, kSAFE_GPU, diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp index 28dda72768..84e97a0ecb 100644 --- a/py/trtorch/csrc/trtorch_py.cpp +++ b/py/trtorch/csrc/trtorch_py.cpp @@ -165,6 +165,7 @@ void log(core::util::logging::LogLevel lvl, const std::string& msg) { PYBIND11_MODULE(_C, m) { py::class_(m, "InputRange") .def(py::init<>()) + .def("__str__", &trtorch::pyapi::InputRange::to_str) .def_readwrite("min", &InputRange::min) .def_readwrite("opt", &InputRange::opt) .def_readwrite("max", &InputRange::max); @@ -237,6 +238,7 @@ PYBIND11_MODULE(_C, m) { py::class_(m, "CompileSpec") .def(py::init<>()) + .def("__str__", &trtorch::pyapi::CompileSpec::stringify) .def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator") .def_readwrite("input_ranges", &CompileSpec::input_ranges) .def_readwrite("op_precision", &CompileSpec::op_precision) @@ -256,6 +258,7 @@ PYBIND11_MODULE(_C, m) { py::class_(m, "Device") .def(py::init<>()) + .def("__str__", &trtorch::pyapi::Device::to_str) .def_readwrite("device_type", &Device::device_type) .def_readwrite("gpu_id", &Device::gpu_id) .def_readwrite("dla_core", &Device::dla_core) @@ -263,6 +266,7 @@ PYBIND11_MODULE(_C, m) { py::class_(m, "TorchFallback") .def(py::init<>()) + .def("__str__", &trtorch::pyapi::TorchFallback::to_str) .def_readwrite("enabled", &TorchFallback::enabled) .def_readwrite("min_block_size", &TorchFallback::min_block_size) .def_readwrite("forced_fallback_operators", &TorchFallback::forced_fallback_operators);