feat: Add support for providing input datatypes in TRTorch
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
peri044 authored and narendasan committed Jul 21, 2021
1 parent bdaacf1 commit a3f4a3c
Showing 12 changed files with 125 additions and 28 deletions.
9 changes: 7 additions & 2 deletions README.md
@@ -20,7 +20,11 @@ More Information / System Architecture:
...
auto compile_settings = trtorch::CompileSpec(dims);
// FP16 execution
compile_settings.op_precision = torch::kFloat;
compile_settings.op_precision = torch::kHalf;
// Set input datatypes. Allowed options: torch::{kFloat, kHalf, kChar, kInt32, kBool}
// Size of input_dtypes should match the number of inputs to the network.
// If input_dtypes is not set, input tensors default to float32
compile_settings.input_dtypes = {torch::kHalf};
// Compile module
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
// Run like normal
@@ -43,7 +47,8 @@ compile_settings = {
"max": [1, 3, 1024, 1024]
}, # For static size [1, 3, 224, 224]
],
"op_precision": torch.half # Run with FP16
"op_precision": torch.half, # Run with FP16
"input_dtypes": [torch.half] # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
}
trt_ts_module = trtorch.compile(torch_script_module, compile_settings)
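For context, here is a minimal end-to-end sketch of the new Python option (hypothetical toy module and shapes; assumes the `input_shapes` key used elsewhere in this README and a CUDA-capable build):

```python
import torch
import trtorch

# Toy module, purely illustrative.
class Norm(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.normalize(x, dim=1)

ts_mod = torch.jit.script(Norm().eval().cuda())

compile_settings = {
    "input_shapes": [[1, 3, 224, 224]],  # one static-shape input
    "op_precision": torch.half,          # run the engine in FP16
    "input_dtypes": [torch.half],        # one dtype per network input
}

trt_mod = trtorch.compile(ts_mod, compile_settings)
out = trt_mod(torch.randn(1, 3, 224, 224).half().cuda())
```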
19 changes: 16 additions & 3 deletions core/conversion/conversion.cpp
@@ -150,14 +150,27 @@ void AddInputs(

auto profile = ctx->builder->createOptimizationProfile();

TRTORCH_CHECK(
ctx->input_dtypes.size() == 0 || ctx->input_dtypes.size() == input_tensors.size(),
"Number of input_dtypes : " << ctx->input_dtypes.size()
<< " should either be 0 or equal to number of input_tensors which is "
<< input_tensors.size() << " (conversion.AddInputs)");

// If input_dtypes is not provided, assume all input tensors are float32
if (ctx->input_dtypes.size() == 0) {
LOG_DEBUG("Input datatypes are not provided explicitly. Default float32 datatype is being used for all inputs");
ctx->input_dtypes = std::vector<nvinfer1::DataType>(input_tensors.size(), nvinfer1::DataType::kFLOAT);
}

for (size_t i = 0; i < input_tensors.size(); i++) {
auto in = input_tensors[i];
auto dims = input_dims[i];
std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
LOG_INFO(
ctx->logger, "Adding Input " << in->debugName() << " named " << name << " in engine (conversion.AddInputs)");
LOG_DEBUG(ctx->logger, "Input shape set to " << dims.input_shape);
auto trt_in = ctx->net->addInput(name.c_str(), ctx->input_type, dims.input_shape);
ctx->logger,
"Adding Input " << in->debugName() << " named : " << name << ", shape: " << dims.input_shape
<< ", dtype : " << ctx->input_dtypes[i] << " in engine (conversion.AddInputs)");
auto trt_in = ctx->net->addInput(name.c_str(), ctx->input_dtypes[i], dims.input_shape);
TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");

profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMIN, dims.min);
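The new `TRTORCH_CHECK` above means a user-supplied dtype list must be empty or contain exactly one entry per network input. A hypothetical Python-side illustration of tripping the check (`two_input_module` and the exact exception type are assumptions; TRTorch check failures generally surface as `RuntimeError` in Python):

```python
import torch
import trtorch

# two_input_module: a hypothetical scripted module taking two tensors.
compile_settings = {
    "input_shapes": [[1, 3, 224, 224], [1, 3, 224, 224]],
    "input_dtypes": [torch.half],  # 1 dtype for 2 inputs -> size check fails
}

try:
    trtorch.compile(two_input_module, compile_settings)
except RuntimeError as err:
    print(err)  # "Number of input_dtypes ... should either be 0 or equal to ..."
```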
7 changes: 4 additions & 3 deletions core/conversion/conversionctx/ConversionCtx.cpp
@@ -61,26 +61,27 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
case nvinfer1::DataType::kHALF:
TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
input_type = nvinfer1::DataType::kHALF;
break;
case nvinfer1::DataType::kINT8:
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
if (!settings.strict_types) {
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
}
input_type = nvinfer1::DataType::kFLOAT;
TRTORCH_CHECK(
settings.calibrator != nullptr,
"Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
cfg->setInt8Calibrator(settings.calibrator);
break;
case nvinfer1::DataType::kFLOAT:
case nvinfer1::DataType::kINT32:
case nvinfer1::DataType::kBOOL:
default:
input_type = nvinfer1::DataType::kFLOAT;
break;
}

op_precision = settings.op_precision;
input_dtypes = settings.input_dtypes;

if (settings.disable_tf32) {
cfg->clearFlag(nvinfer1::BuilderFlag::kTF32);
3 changes: 2 additions & 1 deletion core/conversion/conversionctx/ConversionCtx.h
@@ -24,6 +24,7 @@ struct Device {

struct BuilderSettings {
nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
std::vector<nvinfer1::DataType> input_dtypes;
bool disable_tf32 = false;
bool refit = false;
bool debug = false;
@@ -57,7 +58,7 @@ struct ConversionCtx {
nvinfer1::IBuilder* builder;
nvinfer1::INetworkDefinition* net;
nvinfer1::IBuilderConfig* cfg;
nvinfer1::DataType input_type;
std::vector<nvinfer1::DataType> input_dtypes;
nvinfer1::DataType op_precision;
BuilderSettings settings;
util::logging::TRTorchLogger logger;
9 changes: 9 additions & 0 deletions cpp/api/include/trtorch/trtorch.h
@@ -115,6 +115,10 @@ struct TRTORCH_API CompileSpec {
kHalf,
/// INT8
kChar,
/// INT32
kInt32,
/// Bool
kBool,
};

/**
@@ -239,6 +243,11 @@ struct TRTORCH_API CompileSpec {
*/
DataType op_precision = DataType::kFloat;

/**
* Data types for input tensors
*/
std::vector<DataType> input_dtypes;

/**
* Prevent Float32 layers from using TF32 data format
*
42 changes: 30 additions & 12 deletions cpp/api/src/compile_spec.cpp
@@ -7,17 +7,26 @@

namespace trtorch {
CompileSpec::DataType::DataType(c10::ScalarType t) {
TRTORCH_CHECK(t == at::kHalf || t == at::kFloat || t == at::kChar, "Data type is unsupported");
TRTORCH_CHECK(
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kInt || t == at::kBool,
"Data type is unsupported");
switch (t) {
case at::kHalf:
value = DataType::kHalf;
break;
case at::kChar:
value = DataType::kChar;
break;
case at::kInt:
value = DataType::kInt32;
break;
case at::kBool:
value = DataType::kBool;
break;
case at::kFloat:
default:
value = DataType::kFloat;
break;
case at::kChar:
value = DataType::kChar;
}
}

@@ -74,19 +83,28 @@ std::vector<core::ir::InputRange> to_vec_internal_input_ranges(std::vector<Compi
return internal;
}

core::CompileSpec to_internal_compile_spec(CompileSpec external) {
core::CompileSpec internal(to_vec_internal_input_ranges(external.input_ranges));

switch (external.op_precision) {
nvinfer1::DataType toTRTDataType(CompileSpec::DataType value) {
switch (value) {
case CompileSpec::DataType::kChar:
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kINT8;
break;
return nvinfer1::DataType::kINT8;
case CompileSpec::DataType::kHalf:
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kHALF;
break;
return nvinfer1::DataType::kHALF;
case CompileSpec::DataType::kInt32:
return nvinfer1::DataType::kINT32;
case CompileSpec::DataType::kBool:
return nvinfer1::DataType::kBOOL;
case CompileSpec::DataType::kFloat:
default:
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kFLOAT;
return nvinfer1::DataType::kFLOAT;
}
}

core::CompileSpec to_internal_compile_spec(CompileSpec external) {
core::CompileSpec internal(to_vec_internal_input_ranges(external.input_ranges));

internal.convert_info.engine_settings.op_precision = toTRTDataType(external.op_precision);
for (auto dtype : external.input_dtypes) {
internal.convert_info.engine_settings.input_dtypes.push_back(toTRTDataType(dtype));
}

internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
25 changes: 25 additions & 0 deletions py/trtorch/_compile_spec.py
@@ -140,6 +140,24 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFall

return info

def _parse_input_dtypes(input_dtypes: List) -> List:
parsed_input_dtypes = []
for dtype in input_dtypes:
if isinstance(dtype, torch.dtype):
if dtype == torch.int8:
parsed_input_dtypes.append(_types.dtype.int8)
elif dtype == torch.half:
parsed_input_dtypes.append(_types.dtype.half)
elif dtype == torch.float:
parsed_input_dtypes.append(_types.dtype.float)
elif dtype == torch.int32:
parsed_input_dtypes.append(_types.dtype.int32)
elif dtype == torch.bool:
parsed_input_dtypes.append(_types.dtype.bool)
else:
raise TypeError("Invalid input dtype. Supported input datatypes include float|half|int8|int32|bool, got: " + str(dtype))

return parsed_input_dtypes
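A quick sketch of the mapping this helper performs (assumes `_types` is the enum module imported at the top of this file, and that `_parse_input_dtypes` stays importable as an internal helper):

```python
import torch
from trtorch import _types
from trtorch._compile_spec import _parse_input_dtypes  # internal helper

# torch dtypes map one-to-one onto the TRTorch enum values.
parsed = _parse_input_dtypes([torch.half, torch.int32, torch.bool])
assert parsed == [_types.dtype.half, _types.dtype.int32, _types.dtype.bool]

_parse_input_dtypes(["float32"])  # not a torch.dtype -> raises TypeError
```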

def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
info = trtorch._C.CompileSpec()
@@ -153,6 +171,9 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
if "op_precision" in compile_spec:
info.op_precision = _parse_op_precision(compile_spec["op_precision"])

if "input_dtypes" in compile_spec:
info.input_dtypes = _parse_input_dtypes(compile_spec["input_dtypes"])

if "calibrator" in compile_spec:
info.ptq_calibrator = compile_spec["calibrator"]

@@ -237,6 +258,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"op_precision": torch.half, # Operating precision set to FP16
"input_dtypes": [torch.half], # List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
"refit": False, # enable refit
"debug": False, # enable debuggable engine
@@ -288,6 +310,9 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
backend_spec._set_device(d)
backend_spec._set_torch_fallback(torch_fallback)
backend_spec._set_op_precision(int(parsed_spec.op_precision))
for dtype in parsed_spec.input_dtypes:
backend_spec._append_input_dtypes(int(dtype))

backend_spec._set_disable_tf32(parsed_spec.disable_tf32)
backend_spec._set_refit(parsed_spec.refit)
backend_spec._set_debug(parsed_spec.debug)
2 changes: 2 additions & 0 deletions py/trtorch/_compiler.py
@@ -41,6 +41,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"op_precision": torch.half, # Operating precision set to FP16
"input_dtypes": [torch.float32] # List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
"refit": false, # enable refit
"debug": false, # enable debuggable engine
"strict_types": false, # kernels should strictly run in operating precision
@@ -106,6 +107,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"op_precision": torch.half, # Operating precision set to FP16
"input_dtypes": [torch.half], # List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
"refit": false, # enable refit
"debug": false, # enable debuggable engine
1 change: 1 addition & 0 deletions py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -43,6 +43,7 @@ void RegisterTRTCompileSpec() {
.def("_set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
.def("_set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
.def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
.def("_append_input_dtypes", &trtorch::pyapi::CompileSpec::appendInputDtypes)
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);

ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
21 changes: 20 additions & 1 deletion py/trtorch/csrc/tensorrt_classes.cpp
@@ -30,6 +30,10 @@ std::string to_str(DataType value) {
return "Half";
case DataType::kChar:
return "Int8";
case DataType::kInt32:
return "Int32";
case DataType::kBool:
return "Bool";
case DataType::kFloat:
default:
return "Float";
@@ -42,6 +46,10 @@ nvinfer1::DataType toTRTDataType(DataType value) {
return nvinfer1::DataType::kINT8;
case DataType::kHalf:
return nvinfer1::DataType::kHALF;
case DataType::kInt32:
return nvinfer1::DataType::kINT32;
case DataType::kBool:
return nvinfer1::DataType::kBOOL;
case DataType::kFloat:
default:
return nvinfer1::DataType::kFLOAT;
@@ -124,8 +132,15 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
for (auto i : input_ranges) {
internal_input_ranges.push_back(i.toInternalInputRange());
}

std::vector<nvinfer1::DataType> trt_input_dtypes;
for (auto dtype : input_dtypes) {
trt_input_dtypes.push_back(toTRTDataType(dtype));
}

auto info = core::CompileSpec(internal_input_ranges);
info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
info.convert_info.engine_settings.input_dtypes = trt_input_dtypes;
info.convert_info.engine_settings.calibrator = ptq_calibrator;
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
info.convert_info.engine_settings.refit = refit;
@@ -159,9 +174,13 @@ std::string CompileSpec::stringify() {
for (auto i : input_ranges) {
ss << i.to_str();
}
std::string enabled = torch_fallback.enabled ? "True" : "False";
ss << " ]" << std::endl;
ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
ss << " \"Input dtypes\": [" << std::endl;
for (auto i : input_dtypes) {
ss << to_str(i);
}
ss << " ]" << std::endl;
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
ss << " \"Refit\": " << refit << std::endl;
ss << " \"Debug\": " << debug << std::endl;
12 changes: 6 additions & 6 deletions py/trtorch/csrc/tensorrt_classes.h
@@ -43,11 +43,7 @@ struct InputRange : torch::CustomClassHolder {
std::string to_str();
};

enum class DataType : int8_t {
kFloat,
kHalf,
kChar,
};
enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool };

std::string to_str(DataType value);
nvinfer1::DataType toTRTDataType(DataType value);
@@ -108,7 +104,9 @@ struct CompileSpec : torch::CustomClassHolder {
void appendInputRange(const c10::intrusive_ptr<InputRange>& ir) {
input_ranges.push_back(*ir);
}

void appendInputDtypes(int64_t dtype) {
input_dtypes.push_back(static_cast<DataType>(dtype));
}
int64_t getPTQCalibratorHandle() {
return (int64_t)ptq_calibrator;
}
@@ -120,6 +118,7 @@
void setTorchFallbackIntrusive(const c10::intrusive_ptr<TorchFallback>& fb) {
torch_fallback = *fb;
}

void setPTQCalibratorViaHandle(int64_t handle) {
ptq_calibrator = (nvinfer1::IInt8Calibrator*)handle;
}
@@ -142,6 +141,7 @@
std::vector<InputRange> input_ranges;
nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
DataType op_precision = DataType::kFloat;
std::vector<DataType> input_dtypes;
bool disable_tf32 = false;
bool refit = false;
bool debug = false;
3 changes: 3 additions & 0 deletions py/trtorch/csrc/trtorch_py.cpp
@@ -176,6 +176,8 @@ PYBIND11_MODULE(_C, m) {
.value("half", DataType::kHalf, "16 bit floating point number")
.value("float16", DataType::kHalf, "16 bit floating point number")
.value("int8", DataType::kChar, "8 bit integer number")
.value("int32", DataType::kInt32, "32 bit integer number")
.value("bool", DataType::kChar, "Boolean value")
.export_values();

py::enum_<DeviceType>(m, "DeviceType", "Enum to specify device kinds to build TensorRT engines for")
@@ -242,6 +244,7 @@ PYBIND11_MODULE(_C, m) {
.def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator")
.def_readwrite("input_ranges", &CompileSpec::input_ranges)
.def_readwrite("op_precision", &CompileSpec::op_precision)
.def_readwrite("input_dtypes", &CompileSpec::input_dtypes)
.def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
.def_readwrite("refit", &CompileSpec::refit)
.def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
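With the `def_readwrite` binding above, `input_dtypes` can also be set directly on the raw `_C.CompileSpec` object as a list of enum values; a small sketch against this internal API (illustrative only):

```python
import trtorch
from trtorch import _types

spec = trtorch._C.CompileSpec()
spec.input_dtypes = [_types.dtype.half, _types.dtype.int32]
print(spec.input_dtypes)  # pybind11 round-trips the std::vector as a Python list
```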
