Skip to content

Commit 536983b

Browse files
committed
feat(disable_tf32): Add a new API to disable TF32
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 1660633 commit 536983b

File tree

13 files changed

+44
-2
lines changed

13 files changed

+44
-2
lines changed

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace conversion {
1212
std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
1313
os << "Settings requested for TensorRT engine:" \
1414
<< "\n Operating Precision: " << s.op_precision \
15+
<< "\n TF32 Floating Point Computation Enabled: " << !s.disable_tf32 \
1516
<< "\n Make Refittable Engine: " << s.refit \
1617
<< "\n Debuggable Engine: " << s.debug \
1718
<< "\n Strict Types: " << s.strict_types \
@@ -77,6 +78,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
7778
}
7879
op_precision = settings.op_precision;
7980

81+
if (settings.disable_tf32) {
82+
cfg->clearFlag(nvinfer1::BuilderFlag::kTF32);
83+
}
84+
8085
if (settings.refit) {
8186
cfg->setFlag(nvinfer1::BuilderFlag::kREFIT);
8287
}

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct Device {
2424

2525
struct BuilderSettings {
2626
nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
27+
bool disable_tf32 = false;
2728
bool refit = false;
2829
bool debug = false;
2930
bool strict_types = false;

cpp/api/include/trtorch/trtorch.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,15 @@ struct TRTORCH_API CompileSpec {
239239
*/
240240
DataType op_precision = DataType::kFloat;
241241

242+
/**
243+
* Prevent Float32 layers from using TF32 data format
244+
*
245+
* TF32 computes inner products by rounding the inputs to 10-bit mantissas
246+
* before multiplying, but accumulates the sum using 23-bit mantissas.
247+
* This is the behavior of FP32 layers by default.
248+
*/
249+
bool disable_tf32 = false;
250+
242251
/**
243252
* Build a refittable engine
244253
*/

cpp/api/src/compile_spec.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
8989
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kFLOAT;
9090
}
9191

92+
internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
9293
internal.convert_info.engine_settings.refit = external.refit;
9394
internal.convert_info.engine_settings.debug = external.debug;
9495
internal.convert_info.engine_settings.strict_types = external.strict_types;

cpp/trtorchc/main.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,12 @@ int main(int argc, char** argv) {
163163
"(Only used when targeting DLA (device-type)) Lets engine run layers on GPU if they are not supported on DLA",
164164
{"allow-gpu-fallback"});
165165

166+
args::Flag disable_tf32(
167+
parser,
168+
"disable-tf32",
169+
"Prevent Float32 layers from using the TF32 data format",
170+
{"disable-tf32"});
171+
166172
args::ValueFlag<std::string> op_precision(
167173
parser,
168174
"precision",
@@ -263,6 +269,10 @@ int main(int argc, char** argv) {
263269
compile_settings.device.allow_gpu_fallback = true;
264270
}
265271

272+
if (disable_tf32) {
273+
compile_settings.disable_tf32 = true;
274+
}
275+
266276
std::string calibration_cache_file_path = "";
267277
if (calibration_cache_file) {
268278
calibration_cache_file_path = resolve_path(args::get(calibration_cache_file));

py/trtorch/_compile_spec.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
135135
if "op_precision" in compile_spec:
136136
info.op_precision = _parse_op_precision(compile_spec["op_precision"])
137137

138+
if "disable_tf32" in compile_spec:
139+
assert isinstance(compile_spec["disable_tf32"], bool)
140+
info.disable_tf32 = compile_spec["disable_tf32"]
141+
138142
if "refit" in compile_spec:
139143
assert isinstance(compile_spec["refit"], bool)
140144
info.refit = compile_spec["refit"]
@@ -201,6 +205,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
201205
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
202206
},
203207
"op_precision": torch.half, # Operating precision set to FP16
208+
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
204209
"refit": False, # enable refit
205210
"debug": False, # enable debuggable engine
206211
"strict_types": False, # kernels should strictly run in operating precision
@@ -239,6 +244,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
239244

240245
backend_spec.set_device(d)
241246
backend_spec.set_op_precision(int(parsed_spec.op_precision))
247+
backend_spec.set_disable_tf32(parsed_spec.disable_tf32)
242248
backend_spec.set_refit(parsed_spec.refit)
243249
backend_spec.set_debug(parsed_spec.debug)
244250
backend_spec.set_refit(parsed_spec.refit)

py/trtorch/_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
9999
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
100100
},
101101
"op_precision": torch.half, # Operating precision set to FP16
102+
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
102103
"refit": false, # enable refit
103104
"debug": false, # enable debuggable engine
104105
"strict_types": false, # kernels should strictly run in operating precision

py/trtorch/csrc/register_tensorrt_classes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ void RegisterTRTCompileSpec() {
3232
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
3333

3434
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
35+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
3536
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
3637
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, debug);
3738
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, strict_types);

py/trtorch/csrc/tensorrt_classes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
9999
}
100100
auto info = core::CompileSpec(internal_input_ranges);
101101
info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
102+
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
102103
info.convert_info.engine_settings.refit = refit;
103104
info.convert_info.engine_settings.debug = debug;
104105
info.convert_info.engine_settings.strict_types = strict_types;
@@ -128,6 +129,7 @@ std::string CompileSpec::stringify() {
128129
}
129130
ss << " ]" << std::endl;
130131
ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
132+
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
131133
ss << " \"Refit\": " << refit << std::endl;
132134
ss << " \"Debug\": " << debug << std::endl;
133135
ss << " \"Strict Types\": " << strict_types << std::endl;

py/trtorch/csrc/tensorrt_classes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ struct CompileSpec : torch::CustomClassHolder {
9999
}
100100

101101
ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
102+
ADD_FIELD_GET_SET(disable_tf32, bool);
102103
ADD_FIELD_GET_SET(refit, bool);
103104
ADD_FIELD_GET_SET(debug, bool);
104105
ADD_FIELD_GET_SET(strict_types, bool);
@@ -111,6 +112,7 @@ struct CompileSpec : torch::CustomClassHolder {
111112

112113
std::vector<InputRange> input_ranges;
113114
DataType op_precision = DataType::kFloat;
115+
bool disable_tf32 = false;
114116
bool refit = false;
115117
bool debug = false;
116118
bool strict_types = false;

0 commit comments

Comments
 (0)