Skip to content

Commit 740eb54

Browse files
committed
feat: support truncating long/double to int/float via an option
Signed-off-by: inocsin <vcheungyi@163.com>
1 parent 5b6bd4c commit 740eb54

File tree

4 files changed

+18
-3
lines changed

4 files changed

+18
-3
lines changed

core/conversion/conversionctx/ConversionCtx.h

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ struct BuilderSettings {
2828
bool refit = false;
2929
bool debug = false;
3030
bool strict_types = false;
31+
bool truncate_long_and_double = false;
3132
Device device;
3233
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
3334
nvinfer1::IInt8Calibrator* calibrator = nullptr;

core/conversion/var/Var.cpp

+11-3
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,18 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
9494
"Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());
9595

9696
nvinfer1::ITensor* out;
97-
97+
auto weights = converters::Weights();
9898
if (isIValue()) {
99-
auto weights = converters::Weights(ctx, ptr_.ivalue->toTensor());
100-
99+
auto tensor = ptr_.ivalue->toTensor();
100+
if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
101+
weights = converters::Weights(ctx, tensor.toType(at::kInt));
102+
LOG_WARNING("Truncate kLong to kInt for IValue");
103+
} else if (tensor.scalar_type() == at::kDouble && ctx->settings.truncate_long_and_double) {
104+
weights = converters::Weights(ctx, tensor.toType(at::kFloat));
105+
LOG_WARNING("Truncate kDouble to kFloat for IValue");
106+
} else {
107+
weights = converters::Weights(ctx, tensor);
108+
}
101109
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
102110
TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");
103111

cpp/api/include/trtorch/trtorch.h

+5
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,11 @@ struct TRTORCH_API CompileSpec {
258258
*/
259259
bool debug = false;
260260

261+
/**
262+
* Truncate long/double tensors to int/float when they are frozen into constants (disabled by default)
263+
*/
264+
bool truncate_long_and_double = false;
265+
261266
/**
262267
* Restrict operating type to only set default operation precision
263268
* (op_precision)

cpp/api/src/compile_spec.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
9292
internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
9393
internal.convert_info.engine_settings.refit = external.refit;
9494
internal.convert_info.engine_settings.debug = external.debug;
95+
internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
9596
internal.convert_info.engine_settings.strict_types = external.strict_types;
9697
internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
9798
internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;

0 commit comments

Comments
 (0)