File tree Expand file tree Collapse file tree 4 files changed +18
-3
lines changed Expand file tree Collapse file tree 4 files changed +18
-3
lines changed Original file line number Diff line number Diff line change @@ -28,6 +28,7 @@ struct BuilderSettings {
2828 bool refit = false ;
2929 bool debug = false ;
3030 bool strict_types = false ;
31+ bool truncate_long_and_double = false ;
3132 Device device;
3233 nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT ;
3334 nvinfer1::IInt8Calibrator* calibrator = nullptr ;
Original file line number Diff line number Diff line change @@ -94,10 +94,18 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
9494 " Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name ());
9595
9696 nvinfer1::ITensor* out;
97-
97+ auto weights = converters::Weights ();
9898 if (isIValue ()) {
99- auto weights = converters::Weights (ctx, ptr_.ivalue ->toTensor ());
100-
99+ auto tensor = ptr_.ivalue ->toTensor ();
100+ if (tensor.scalar_type () == at::kLong && ctx->settings .truncate_long_and_double ) {
101+ weights = converters::Weights (ctx, tensor.toType (at::kInt ));
102+ LOG_WARNING (" Truncate kLong to kInt for IValue" );
103+ } else if (tensor.scalar_type () == at::kDouble && ctx->settings .truncate_long_and_double ) {
104+ weights = converters::Weights (ctx, tensor.toType (at::kFloat ));
105+ LOG_WARNING (" Truncate kDouble to kFloat for IValue" );
106+ } else {
107+ weights = converters::Weights (ctx, tensor);
108+ }
101109 auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
102110 TRTORCH_CHECK (const_layer, " Unable to freeze tensor into constant layer" );
103111
Original file line number Diff line number Diff line change @@ -258,6 +258,11 @@ struct TRTORCH_API CompileSpec {
258258 */
259259 bool debug = false ;
260260
261+ /* *
262+ * Truncate long/double type to int/float type
263+ */
264+ bool truncate_long_and_double = false ;
265+
261266 /* *
262267 * Restrict operating type to only set default operation precision
263268 * (op_precision)
Original file line number Diff line number Diff line change @@ -92,6 +92,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
9292 internal.convert_info .engine_settings .disable_tf32 = external.disable_tf32 ;
9393 internal.convert_info .engine_settings .refit = external.refit ;
9494 internal.convert_info .engine_settings .debug = external.debug ;
95+ internal.convert_info .engine_settings .truncate_long_and_double = external.truncate_long_and_double ;
9596 internal.convert_info .engine_settings .strict_types = external.strict_types ;
9697 internal.convert_info .engine_settings .device .allow_gpu_fallback = external.device .allow_gpu_fallback ;
9798 internal.convert_info .engine_settings .max_batch_size = external.max_batch_size ;
You can’t perform that action at this time.
0 commit comments