File tree 4 files changed +18
-3
lines changed
4 files changed +18
-3
lines changed Original file line number Diff line number Diff line change @@ -28,6 +28,7 @@ struct BuilderSettings {
28
28
bool refit = false ;
29
29
bool debug = false ;
30
30
bool strict_types = false ;
31
+ bool truncate_long_and_double = false ;
31
32
Device device;
32
33
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT ;
33
34
nvinfer1::IInt8Calibrator* calibrator = nullptr ;
Original file line number Diff line number Diff line change @@ -94,10 +94,18 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
94
94
" Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name ());
95
95
96
96
nvinfer1::ITensor* out;
97
-
97
+ auto weights = converters::Weights ();
98
98
if (isIValue ()) {
99
- auto weights = converters::Weights (ctx, ptr_.ivalue ->toTensor ());
100
-
99
+ auto tensor = ptr_.ivalue ->toTensor ();
100
+ if (tensor.scalar_type () == at::kLong && ctx->settings .truncate_long_and_double ) {
101
+ weights = converters::Weights (ctx, tensor.toType (at::kInt ));
102
+ LOG_WARNING (" Truncate kLong to kInt for IValue" );
103
+ } else if (tensor.scalar_type () == at::kDouble && ctx->settings .truncate_long_and_double ) {
104
+ weights = converters::Weights (ctx, tensor.toType (at::kFloat ));
105
+ LOG_WARNING (" Truncate kDouble to kFloat for IValue" );
106
+ } else {
107
+ weights = converters::Weights (ctx, tensor);
108
+ }
101
109
auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
102
110
TRTORCH_CHECK (const_layer, " Unable to freeze tensor into constant layer" );
103
111
Original file line number Diff line number Diff line change @@ -258,6 +258,11 @@ struct TRTORCH_API CompileSpec {
258
258
*/
259
259
bool debug = false ;
260
260
261
+ /* *
262
+ * Truncate long/double type to int/float type
263
+ */
264
+ bool truncate_long_and_double = false ;
265
+
261
266
/* *
262
267
* Restrict operating type to only set default operation precision
263
268
* (op_precision)
Original file line number Diff line number Diff line change @@ -92,6 +92,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
92
92
internal.convert_info .engine_settings .disable_tf32 = external.disable_tf32 ;
93
93
internal.convert_info .engine_settings .refit = external.refit ;
94
94
internal.convert_info .engine_settings .debug = external.debug ;
95
+ internal.convert_info .engine_settings .truncate_long_and_double = external.truncate_long_and_double ;
95
96
internal.convert_info .engine_settings .strict_types = external.strict_types ;
96
97
internal.convert_info .engine_settings .device .allow_gpu_fallback = external.device .allow_gpu_fallback ;
97
98
internal.convert_info .engine_settings .max_batch_size = external.max_batch_size ;
You can’t perform that action at this time.
0 commit comments