From a234335f25f5d633091e0ed8bad902a384c2a09a Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Mon, 18 Oct 2021 13:28:15 -0700 Subject: [PATCH] feat!: Changing the default behavior for selecting the input type BREAKING CHANGE: This commit changes the default behavior of the compiler where if the user does not specify an input data type explicity instead of using the enabled precision, now the compiler will inspect the model provided to infer the data type for the input that will not cause an error if the model was run in torch. In practice this means - If the weights are in FP32 for the first tensor calculation then default input type is FP32 - If the weights are in FP16 for the first tensor calculation then default input type is FP16 - etc. If the data type cannot be determined the compiler will default to FP32. This calculation is done per input tensor so if one input is inferred to use FP32 and another INT32 then the expected types will be the same (FP32, INT32) As was the same before if the user defines the data type explicitly or provides an example tensor the data type specified there will be respected Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/compiler.cpp | 56 ++++++++------ core/lowering/lowering.cpp | 1 + core/util/jit_util.cpp | 5 +- core/util/jit_util.h | 5 +- core/util/logging/TRTorchLogger.cpp | 2 +- cpp/include/trtorch/trtorch.h | 15 ++-- cpp/src/compile_spec.cpp | 20 ++--- py/trtorch/Input.py | 44 +++++++++-- py/trtorch/csrc/tensorrt_classes.cpp | 2 + py/trtorch/csrc/tensorrt_classes.h | 2 +- py/trtorch/csrc/trtorch_py.cpp | 1 + tests/core/test_detecting_input_type.cpp | 51 +++++++++++++ tests/cpp/test_default_input_types.cpp | 95 +++++++++++++++++++++++- tests/py/test_api.py | 82 ++++++++++++++++++-- 14 files changed, 310 insertions(+), 71 deletions(-) create mode 100644 tests/core/test_detecting_input_type.cpp diff --git a/core/compiler.cpp b/core/compiler.cpp index ccccc512ae..1095a88587 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -287,22 +287,45 @@ GraphAndMapping ConstructFallbackGraph( return {new_g, old_to_new_g}; } + +void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr& g, ir::StaticParams& static_params, const util::InputTypeMap& first_use_type_map) { + // Associate input specs with inputs + cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params)); + + for (auto& in : g->inputs()) { + auto est_type_opt = first_use_type_map.find(in)->second; + ir::Input& spec = cfg.convert_info.inputs.find(in)->second; + if (est_type_opt && !spec.dtype_is_user_defined) { + // If we can calculate the type from the graph and the type was not defined by the user then use the calculated type + LOG_INFO("Since input type is not explicitly defined, infering using first tensor calculation\n Found input " + << in->debugName() << " has type " << est_type_opt.value() << ". If this is incorrect explicitly set dtype for input and file a bug"); + spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value()); + } else if (!est_type_opt && !spec.dtype_is_user_defined) { + // If we cannot calculate the type and the user did not define the type, then default to FP32 + LOG_WARNING( + "Cannot deterime input type from calcuations in graph for input " + << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity"); + spec.dtype = nvinfer1::DataType::kFLOAT; + } else { + // The user defined the type so no changes are necessary + } + } +} + std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) { // Go through Lowering to simplify graph and extract weight parameters auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info); - auto convert_cfg = std::move(cfg.convert_info); auto g = graph_and_parameters.first; - auto params = graph_and_parameters.second; auto static_params = ir::get_static_params(g->inputs(), params); + // Infer the type of an input from the weights of the calculation + auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block()); - LOG_INFO(*g << "(CompileGraph)\n"); + MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types); - // Move the user defined inputs to the convert_cfg since some might be static; - convert_cfg.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params)); + auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params); - auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, static_params); return std::move(engine); } @@ -331,27 +354,12 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info); auto g = graph_and_parameters.first; - LOG_INFO("Lowered Graph: " << *g); auto params = graph_and_parameters.second; auto static_params = ir::get_static_params(g->inputs(), params); - - cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params)); - - // If the user did not explicitly set the input type, then use the first - // tensor calculation to infer type. + // Infer the type of an input from the weights of the calculation auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block()); - for (auto& in : g->inputs()) { - auto est_type_opt = first_use_types[in]; - ir::Input& spec = cfg.convert_info.inputs.find(in)->second; - if (est_type_opt && !spec.dtype_is_user_defined) { - spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value()); - } else if (!est_type_opt && !spec.dtype_is_user_defined) { - LOG_WARNING( - "Cannot deterime input type from calcuations in graph for input " - << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity"); - spec.dtype = nvinfer1::DataType::kFLOAT; - } - } + + MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types); if (cfg.partition_info.enabled) { auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types); diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 4be3e403aa..232dc08e19 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -97,6 +97,7 @@ std::pair, std::vector> L // Is this necessary? // lowering::LowerBlock(g->block()); + LOG_INFO("Lowered Graph: " << *(graph_and_ivalues.first)); return graph_and_ivalues; } diff --git a/core/util/jit_util.cpp b/core/util/jit_util.cpp index 128546e5f8..91bb1bede0 100644 --- a/core/util/jit_util.cpp +++ b/core/util/jit_util.cpp @@ -96,9 +96,8 @@ c10::optional get_value_first_calc_dtype_opt(torch::jit::Block* return dtype; } -std::unordered_map> get_block_first_calc_dtypes_opt( - torch::jit::Block* b) { - std::unordered_map> types; +InputTypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b) { + InputTypeMap types; for (auto i : b->inputs()) { if (i->type() == c10::TensorType::get()) { diff --git a/core/util/jit_util.h b/core/util/jit_util.h index 7fa0739873..082441eeb1 100644 --- a/core/util/jit_util.h +++ b/core/util/jit_util.h @@ -9,6 +9,8 @@ namespace trtorch { namespace core { namespace util { +using InputTypeMap = std::unordered_map>; + inline std::string node_info(const torch::jit::Node* n) { std::stringstream ss; ss << *n; @@ -61,8 +63,7 @@ inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) { } c10::optional get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in); -std::unordered_map> get_block_first_calc_dtypes_opt( - torch::jit::Block* b); +InputTypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b); } // namespace util } // namespace core diff --git a/core/util/logging/TRTorchLogger.cpp b/core/util/logging/TRTorchLogger.cpp index 1be2cfa3ce..0f7030193a 100644 --- a/core/util/logging/TRTorchLogger.cpp +++ b/core/util/logging/TRTorchLogger.cpp @@ -125,7 +125,7 @@ namespace { TRTorchLogger& get_global_logger() { #ifndef NDEBUG - static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kDEBUG, true); + static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kGRAPH, true); #else static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kERROR, false); #endif diff --git a/cpp/include/trtorch/trtorch.h b/cpp/include/trtorch/trtorch.h index 0c298a3321..874221e523 100644 --- a/cpp/include/trtorch/trtorch.h +++ b/cpp/include/trtorch/trtorch.h @@ -387,7 +387,7 @@ struct TRTORCH_API CompileSpec { * / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8) * * @param shape Input tensor shape - * @param dtype Expected data type for the input (Defaults to Float32) + * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32) * @param format Expected tensor format for the input (Defaults to contiguous) */ Input(std::vector shape, TensorFormat format = TensorFormat::kContiguous); @@ -398,7 +398,7 @@ struct TRTORCH_API CompileSpec { * tensor format * * @param shape Input tensor shape - * @param dtype Expected data type for the input (Defaults to Float32) + * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32) * @param format Expected tensor format for the input (Defaults to contiguous) */ Input(std::vector shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous); @@ -421,7 +421,7 @@ struct TRTORCH_API CompileSpec { * allow the user to configure expected input shape tensor format * * @param shape Input tensor shape - * @param dtype Expected data type for the input (Defaults to Float32) + * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32) * @param format Expected tensor format for the input (Defaults to contiguous) */ Input(c10::ArrayRef shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous); @@ -451,7 +451,7 @@ struct TRTORCH_API CompileSpec { * @param min_shape Minimum shape for input tensor * @param opt_shape Target optimization shape for input tensor * @param max_shape Maximum acceptible shape for input tensor - * @param dtype Expected data type for the input (Defaults to Float32) + * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32) * @param format Expected tensor format for the input (Defaults to contiguous) */ Input( @@ -486,7 +486,7 @@ struct TRTORCH_API CompileSpec { * @param min_shape Minimum shape for input tensor * @param opt_shape Target optimization shape for input tensor * @param max_shape Maximum acceptible shape for input tensor - * @param dtype Expected data type for the input (Defaults to Float32) + * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32) * @param format Expected tensor format for the input (Defaults to contiguous) */ Input( @@ -506,14 +506,9 @@ struct TRTORCH_API CompileSpec { */ Input(at::Tensor tensor); - bool get_explicit_set_dtype() { - return explicit_set_dtype; - } - private: friend std::ostream& operator<<(std::ostream& os, const Input& input); bool input_is_dynamic; - bool explicit_set_dtype; }; /** diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp index ff28bd2fe9..04bec052f6 100644 --- a/cpp/src/compile_spec.cpp +++ b/cpp/src/compile_spec.cpp @@ -73,7 +73,6 @@ std::ostream& operator<<(std::ostream& os, const CompileSpec::Input& input) { } nvinfer1::DataType toTRTDataType(CompileSpec::DataType value) { - TRTORCH_CHECK(!(value == CompileSpec::DataType::kUnknown), "Data type is unknown"); switch (value) { case CompileSpec::DataType::kChar: return nvinfer1::DataType::kINT8; @@ -162,8 +161,7 @@ CompileSpec::Input::Input(std::vector shape, TensorFormat format) { this->min_shape = shape; this->max_shape = shape; this->shape = shape; - this->dtype = dtype; - this->explicit_set_dtype = false; + this->dtype = CompileSpec::DataType::kUnknown; this->format = format; this->input_is_dynamic = false; } @@ -174,7 +172,6 @@ CompileSpec::Input::Input(std::vector shape, DataType dtype, TensorForm this->max_shape = shape; this->shape = shape; this->dtype = dtype; - this->explicit_set_dtype = true; this->format = format; this->input_is_dynamic = false; } @@ -184,8 +181,7 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, TensorFormat format) { this->min_shape = core::util::toVec(shape); this->max_shape = core::util::toVec(shape); this->shape = core::util::toVec(shape); - this->dtype = DataType::kFloat; - this->explicit_set_dtype = false; + this->dtype = CompileSpec::DataType::kUnknown; this->format = format; this->input_is_dynamic = false; } @@ -196,7 +192,6 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat f this->max_shape = core::util::toVec(shape); this->shape = core::util::toVec(shape); this->dtype = dtype; - this->explicit_set_dtype = true; this->format = format; this->input_is_dynamic = false; } @@ -210,8 +205,7 @@ CompileSpec::Input::Input( this->min_shape = min_shape; this->max_shape = max_shape; this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape); - this->dtype = dtype; - this->explicit_set_dtype = false; + this->dtype = CompileSpec::DataType::kUnknown; this->format = format; this->input_is_dynamic = true; } @@ -227,7 +221,6 @@ CompileSpec::Input::Input( this->max_shape = max_shape; this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape); this->dtype = dtype; - this->explicit_set_dtype = true; this->format = format; this->input_is_dynamic = true; } @@ -241,8 +234,7 @@ CompileSpec::Input::Input( this->min_shape = core::util::toVec(min_shape); this->max_shape = core::util::toVec(max_shape); this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape); - this->dtype = dtype; - this->explicit_set_dtype = false; + this->dtype = CompileSpec::DataType::kUnknown; this->format = format; this->input_is_dynamic = true; } @@ -258,7 +250,6 @@ CompileSpec::Input::Input( this->max_shape = core::util::toVec(max_shape); this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape); this->dtype = dtype; - this->explicit_set_dtype = true; this->format = format; this->input_is_dynamic = true; } @@ -269,7 +260,6 @@ CompileSpec::Input::Input(at::Tensor tensor) { this->max_shape = tensor.sizes().vec(); this->shape = tensor.sizes().vec(); this->dtype = tensor.scalar_type(); - this->explicit_set_dtype = true; TRTORCH_ASSERT( tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::Contiguous), "Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last"); @@ -292,7 +282,7 @@ core::ir::Input to_internal_input(CompileSpec::Input& i) { i.max_shape, toTRTDataType(i.dtype), toTRTTensorFormat(i.format), - i.get_explicit_set_dtype()); + !(i.dtype == CompileSpec::DataType::kUnknown)); } std::vector to_vec_internal_inputs(std::vector& external) { diff --git a/py/trtorch/Input.py b/py/trtorch/Input.py index 51cf4f6860..27daab0a27 100644 --- a/py/trtorch/Input.py +++ b/py/trtorch/Input.py @@ -30,7 +30,7 @@ class _ShapeMode(Enum): shape_mode = None #: (trtorch.Input._ShapeMode): Is input statically or dynamically shaped shape = None #: (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }`` - dtype = _types.dtype.float32 #: The expected data type of the input tensor (default: trtorch.dtype.float32) + dtype = _types.dtype.unknown #: The expected data type of the input tensor (default: trtorch.dtype.float32) _explicit_set_dtype = False format = _types.TensorFormat.contiguous #: The expected format of the input tensor (default: trtorch.TensorFormat.NCHW) @@ -133,16 +133,44 @@ def __str__(self) -> str: def _to_internal(self) -> trtorch._C.Input: internal_in = trtorch._C.Input() if self.shape_mode == Input._ShapeMode.DYNAMIC: - internal_in.min = self.shape["min_shape"] - internal_in.opt = self.shape["opt_shape"] - internal_in.max = self.shape["max_shape"] + if not Input._supported_input_size_type(self.shape["min_shape"]): + raise TypeError( + "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + + str(type(self.shape["min_shape"])) + " for min_shape") + else: + internal_in.min = self.shape["min_shape"] + + if not Input._supported_input_size_type(self.shape["opt_shape"]): + raise TypeError( + "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + + str(type(self.shape["opt_shape"])) + " for opt_shape") + else: + internal_in.min = self.shape["op_shape"] + + if not Input._supported_input_size_type(self.shape["max_shape"]): + raise TypeError( + "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + + str(type(self.shape["max_shape"])) + " for max_shape") + else: + internal_in.min = self.shape["opt_shape"] internal_in.input_is_dynamic = True else: - internal_in.opt = self.shape + if not Input._supported_input_size_type(self.shape): + raise TypeError( + "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + + str(type(self.shape)) + " for shape") + else: + internal_in.opt = self.shape internal_in.input_is_dynamic = False - internal_in.dtype = self.dtype + + if self.dtype != _types.dtype.unknown: + self._explicit_set_dtype = True + else: + self._explicit_set_dtype = False + + internal_in.dtype = Input._parse_dtype(self.dtype) internal_in._explicit_set_dtype = self._explicit_set_dtype - internal_in.format = self.format + internal_in.format = Input._parse_format(self.format) return internal_in @staticmethod @@ -172,7 +200,7 @@ def _parse_dtype(dtype: Any) -> _types.dtype: "Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: " + str(dtype)) - elif isinstance(dtype, _types.DataTypes): + elif isinstance(dtype, _types.dtype): return dtype else: diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp index cf037575a1..913493a414 100644 --- a/py/trtorch/csrc/tensorrt_classes.cpp +++ b/py/trtorch/csrc/tensorrt_classes.cpp @@ -33,6 +33,8 @@ nvinfer1::DataType toTRTDataType(DataType value) { return nvinfer1::DataType::kBOOL; case DataType::kFloat: return nvinfer1::DataType::kFLOAT; + case DataType::kUnknown: + return nvinfer1::DataType::kFLOAT; default: TRTORCH_THROW_ERROR("Unknown data type: " << to_str(value)); } diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h index b7b5d08873..815c2a0ce4 100644 --- a/py/trtorch/csrc/tensorrt_classes.h +++ b/py/trtorch/csrc/tensorrt_classes.h @@ -27,7 +27,7 @@ namespace pyapi { return static_cast(field_name); \ } -enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool }; +enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool, kUnknown }; std::string to_str(DataType value); nvinfer1::DataType toTRTDataType(DataType value); diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp index e8d3c9e696..beffc10dd3 100644 --- a/py/trtorch/csrc/trtorch_py.cpp +++ b/py/trtorch/csrc/trtorch_py.cpp @@ -186,6 +186,7 @@ PYBIND11_MODULE(_C, m) { .value("int8", DataType::kChar, "8 bit integer number") .value("int32", DataType::kInt32, "32 bit integer number") .value("bool", DataType::kChar, "Boolean value") + .value("unknown", DataType::kUnknown, "Unknown data type") .export_values(); py::enum_(m, "DeviceType", "Enum to specify device kinds to build TensorRT engines for") diff --git a/tests/core/test_detecting_input_type.cpp b/tests/core/test_detecting_input_type.cpp new file mode 100644 index 0000000000..c7a279d38a --- /dev/null +++ b/tests/core/test_detecting_input_type.cpp @@ -0,0 +1,51 @@ +#include +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/script.h" +#include "core/util/prelude.h" +#include "core/lowering/lowering.h" +#include "trtorch/trtorch.h" + +TEST(CoreTest, DetectingInputTypeWorksCorrectFP32) { + torch::jit::script::Module mod; + try { + mod = torch::jit::load("tests/modules/mobilenet_v2_scripted.jit.pt"); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + ASSERT_TRUE(false); + } + + auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward", {}); + auto g = graph_and_parameters.first; + + auto input_types = trtorch::core::util::get_block_first_calc_dtypes_opt(g->block()); + + for (auto in : input_types) { + c10::optional& detected_type_opt = in.second; + ASSERT_TRUE(detected_type_opt); + ASSERT_TRUE(detected_type_opt.value() == at::kFloat); + } +} + +TEST(CoreTest, DetectingInputTypeWorksCorrectFP16) { + torch::jit::script::Module mod; + try { + mod = torch::jit::load("tests/modules/mobilenet_v2_scripted.jit.pt"); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + ASSERT_TRUE(false); + } + + mod.to(at::kHalf); + + auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward", {}); + auto g = graph_and_parameters.first; + + auto input_types = trtorch::core::util::get_block_first_calc_dtypes_opt(g->block()); + + for (auto in : input_types) { + c10::optional& detected_type_opt = in.second; + ASSERT_TRUE(detected_type_opt); + ASSERT_TRUE(detected_type_opt.value() == at::kHalf); + } +} diff --git a/tests/cpp/test_default_input_types.cpp b/tests/cpp/test_default_input_types.cpp index 1522126791..9fd37d936a 100644 --- a/tests/cpp/test_default_input_types.cpp +++ b/tests/cpp/test_default_input_types.cpp @@ -1,11 +1,34 @@ #include "cpp_api_test.h" +#include "trtorch/logging.h" + +TEST_P(CppAPITests, InputsUseDefaultFP32) { + trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kINFO); + + std::vector jit_inputs_ivalues; + std::vector trt_inputs_ivalues; + for (auto in_shape : input_shapes) { + auto in = at::randn(in_shape, {at::kCUDA}); + trt_inputs_ivalues.push_back(in.clone()); + } + + auto in = trtorch::CompileSpec::Input(input_shapes[0]); + auto spec = trtorch::CompileSpec({in}); + spec.enabled_precisions.insert(trtorch::CompileSpec::DataType::kHalf); + + auto trt_mod = trtorch::CompileGraph(mod, spec); + torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues); + std::vector trt_results; + trt_results.push_back(trt_results_ivalues.toTensor()); + // If exits without error successfully defaults to FP32 +} + +TEST_P(CppAPITests, InputsUseDefaultFP16) { + trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kINFO); -TEST_P(CppAPITests, InputsUseDefault) { std::vector jit_inputs_ivalues; std::vector trt_inputs_ivalues; for (auto in_shape : input_shapes) { auto in = at::randn(in_shape, {at::kCUDA}); - jit_inputs_ivalues.push_back(in.clone().to(torch::kHalf)); trt_inputs_ivalues.push_back(in.clone().to(torch::kHalf)); } @@ -22,6 +45,74 @@ TEST_P(CppAPITests, InputsUseDefault) { // If exits without error successfully defaults to FP16 } +TEST_P(CppAPITests, InputsUseDefaultFP16WithoutFP16Enabled) { + trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kINFO); + + std::vector jit_inputs_ivalues; + std::vector trt_inputs_ivalues; + for (auto in_shape : input_shapes) { + auto in = at::randn(in_shape, {at::kCUDA}); + trt_inputs_ivalues.push_back(in.clone().to(torch::kHalf)); + } + + auto in = trtorch::CompileSpec::Input(input_shapes[0]); + auto spec = trtorch::CompileSpec({in}); + + mod.to(torch::kHalf); + + auto trt_mod = trtorch::CompileGraph(mod, spec); + torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues); + std::vector trt_results; + trt_results.push_back(trt_results_ivalues.toTensor()); + // If exits without error successfully defaults to FP16 +} + +TEST_P(CppAPITests, InputsRespectUserSettingFP16WeightsFP32In) { + trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kINFO); + + std::vector jit_inputs_ivalues; + std::vector trt_inputs_ivalues; + for (auto in_shape : input_shapes) { + auto in = at::randn(in_shape, {at::kCUDA}); + trt_inputs_ivalues.push_back(in.clone()); + } + + auto in = trtorch::CompileSpec::Input(input_shapes[0]); + in.dtype = torch::kF32; + auto spec = trtorch::CompileSpec({in}); + spec.enabled_precisions.insert(trtorch::CompileSpec::DataType::kHalf); + + mod.to(torch::kHalf); + + auto trt_mod = trtorch::CompileGraph(mod, spec); + torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues); + std::vector trt_results; + trt_results.push_back(trt_results_ivalues.toTensor()); + // If exits without error successfully defaults to FP16 +} + +TEST_P(CppAPITests, InputsRespectUserSettingFP32WeightsFP16In) { + trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kINFO); + + std::vector jit_inputs_ivalues; + std::vector trt_inputs_ivalues; + for (auto in_shape : input_shapes) { + auto in = at::randn(in_shape, {at::kCUDA}); + trt_inputs_ivalues.push_back(in.clone().to(torch::kHalf)); + } + + auto in = trtorch::CompileSpec::Input(input_shapes[0]); + in.dtype = torch::kF16; + auto spec = trtorch::CompileSpec({in}); + spec.enabled_precisions.insert(trtorch::CompileSpec::DataType::kHalf); + + auto trt_mod = trtorch::CompileGraph(mod, spec); + torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues); + std::vector trt_results; + trt_results.push_back(trt_results_ivalues.toTensor()); + // If exits without error successfully defaults to FP16 +} + INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, diff --git a/tests/py/test_api.py b/tests/py/test_api.py index c28cdaa27b..241f9a7609 100644 --- a/tests/py/test_api.py +++ b/tests/py/test_api.py @@ -2,6 +2,7 @@ import trtorch import torch import torchvision.models as models +import copy from model_test_case import ModelTestCase @@ -75,8 +76,6 @@ def test_compile_script_from_dict(self): self.assertTrue(same < 2e-2) - - class TestCompileHalf(ModelTestCase): def setUp(self): @@ -135,7 +134,6 @@ def test_compile_script(self): "device": { "device_type": trtorch.DeviceType.GPU, "gpu_id": 0, - "dla_core": 0, "allow_gpu_fallback": False, "disable_tf32": False }, @@ -161,7 +159,6 @@ def test_compile_script(self): "device": { "device_type": trtorch.DeviceType.GPU, "gpu_id": 0, - "dla_core": 0, "allow_gpu_fallback": False, "disable_tf32": False }, @@ -187,7 +184,6 @@ def test_pt_to_trt_to_pt(self): "device": { "device_type": trtorch.DeviceType.GPU, "gpu_id": 0, - "dla_core": 0, "allow_gpu_fallback": False, "disable_tf32": False } @@ -199,6 +195,80 @@ def test_pt_to_trt_to_pt(self): self.assertTrue(same < 2e-3) +class TestInputTypeDefaultsFP32Model(ModelTestCase): + + def setUp(self): + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + + def test_input_use_default_fp32(self): + ts_model = torch.jit.script(self.model) + trt_mod = trtorch.compile(ts_model, + inputs=[trtorch.Input(self.input.shape)], + enabled_precisions={torch.float, torch.half}) + trt_mod(self.input) + + def test_input_respect_user_setting_fp32_weights_fp16_in(self): + ts_model = torch.jit.script(self.model) + trt_mod = trtorch.compile(ts_model, + inputs=[self.input.half()], + enabled_precisions={torch.float, torch.half}) + trt_mod(self.input.half()) + + def test_input_respect_user_setting_fp32_weights_fp16_in_non_constructor(self): + ts_model = torch.jit.script(self.model) + input_spec = trtorch.Input(self.input.shape) + input_spec.dtype = torch.half + + trt_mod = trtorch.compile(ts_model, + inputs=[input_spec], + enabled_precisions={torch.float, torch.half}) + trt_mod(self.input.half()) + + +class TestInputTypeDefaultsFP16Model(ModelTestCase): + + def setUp(self): + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + + def test_input_use_default_fp16(self): + half_mod = torch.jit.script(self.model) + half_mod.half() + + trt_mod = trtorch.compile(half_mod, + inputs=[trtorch.Input(self.input.shape)], + enabled_precisions={torch.float, torch.half}) + trt_mod(self.input.half()) + + def test_input_use_default_fp16_without_fp16_enabled(self): + half_mod = torch.jit.script(self.model) + half_mod.half() + + trt_mod = trtorch.compile(half_mod, + inputs=[trtorch.Input(self.input.shape)]) + trt_mod(self.input.half()) + + def test_input_respect_user_setting_fp16_weights_fp32_in(self): + half_mod = torch.jit.script(self.model) + half_mod.half() + + trt_mod = trtorch.compile(half_mod, + inputs=[self.input], + enabled_precisions={torch.float, torch.half}) + trt_mod(self.input) + + def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self): + half_mod = torch.jit.script(self.model) + half_mod.half() + + input_spec = trtorch.Input(self.input.shape) + input_spec.dtype = torch.float + + trt_mod = trtorch.compile(half_mod, + inputs=[input_spec], + enabled_precisions={torch.float, torch.half}) + trt_mod(self.input) + + class TestCheckMethodOpSupport(unittest.TestCase): def setUp(self): @@ -284,6 +354,8 @@ def test_suite(): suite.addTest(TestCompileHalf.parametrize(TestCompileHalf, model=models.resnet18(pretrained=True))) suite.addTest(TestCompileHalfDefault.parametrize(TestCompileHalfDefault, model=models.resnet18(pretrained=True))) suite.addTest(TestPTtoTRTtoPT.parametrize(TestPTtoTRTtoPT, model=models.mobilenet_v2(pretrained=True))) + suite.addTest(TestInputTypeDefaultsFP32Model.parametrize(TestInputTypeDefaultsFP32Model, model=models.resnet18(pretrained=True))) + suite.addTest(TestInputTypeDefaultsFP16Model.parametrize(TestInputTypeDefaultsFP16Model, model=models.resnet18(pretrained=True))) suite.addTest(TestFallbackToTorch.parametrize(TestFallbackToTorch, model=models.resnet18(pretrained=True))) suite.addTest( TestModuleFallbackToTorch.parametrize(TestModuleFallbackToTorch, model=models.resnet18(pretrained=True)))