fix: Properly cast intermediate Int8 tensors to TensorRT Engines in Fallback #1549

Merged 2 commits on Dec 22, 2022
1 change: 1 addition & 0 deletions core/partitioning/partitioninginfo/PartitioningInfo.h
@@ -17,6 +17,7 @@ struct PartitioningInfo {
std::vector<std::string> forced_fallback_operators;
bool truncate_long_and_double;
ir::Device target_device;
bool cast_int8_inputs = false;

std::string getGPUDeviceString() const {
return "cuda:" + std::to_string(target_device.gpu_id);
68 changes: 56 additions & 12 deletions core/partitioning/shape_analysis.cpp
@@ -99,18 +99,24 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
return nullptr;
}

torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
torch::jit::Node* createCastNode(
SegmentedBlock& seg_block,
size_t index,
bool is_input,
at::ScalarType dtype,
std::string device,
bool force_create_node = false) {
auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
auto g = seg_block.g();
// if we can find an upstream aten::to node, we use its parameters for creating the new cast node
if (cast_node) {
if (cast_node && !force_create_node) {
std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
if (!is_input) {
// if this value is an output, we need to cast it to the requested dtype
auto const_val = g->insertConstant(3);
auto const_val = g->insertConstant(dtype);
if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
value_map.insert({cast_node->inputs()[2], const_val});
} else {
@@ -122,7 +128,7 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
// auto cast_node = g->prependNode(g->createClone(cast_node, env));
} else {
// if there is no explicit aten::to cast operation, we need to create one
auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
auto const_type = g->insertConstant(dtype);
auto const_zero = g->insertConstant(0);
const_zero->setType(torch::jit::BoolType::get());
auto cuda = g->insertConstant(device);
@@ -222,27 +228,56 @@ void getSegmentsOutputByRunning(

auto target_device = partitioning_info.getGPUDeviceString();

// auto int64 <=> int32 conversion
if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
// auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
if (seg_block.target() == SegmentedBlock::kTorch) {
// First, check if there is an Int64 input
for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
if (t == at::kLong) {
if (t == at::kLong && partitioning_info.truncate_long_and_double) {
LOG_DEBUG(
"Detected graph Long tensor input type during shape analysis, "
<< "inserting aten::to cast to Long to ensure this Torch block receives "
<< "a Long-type tensor input.");
// we add a cast operation to cast the type to Int64
auto cast_node = createCastNode(seg_block, i, true, target_device);
auto cast_node = createCastNode(seg_block, i, true, at::kLong, target_device);
seg_block.g()->prependNode(cast_node);
seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
} else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
LOG_DEBUG(
"Detected graph Byte tensor input type during shape analysis, "
<< "inserting aten::to cast to Byte to ensure this Torch block receives "
<< "a Byte-type tensor input.");
// If the input has type Byte, ensure it is cast to the correct type
auto cast_node = createCastNode(seg_block, i, true, at::kByte, target_device, /*force_create_node=*/true);
seg_block.g()->prependNode(cast_node);
seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
}
}
}

for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
if (t == at::kLong) {
auto cast_node = createCastNode(seg_block, i, false, target_device);

// If the output has type Long and truncation was requested, insert truncate
if (t == at::kLong && partitioning_info.truncate_long_and_double) {
LOG_DEBUG(
"Detected graph Long tensor output type during shape analysis, "
<< "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
<< "receives an Int-type tensor input.");
auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device);
seg_block.g()->appendNode(cast_node);
seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
} else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
LOG_DEBUG(
"Detected graph Byte tensor output type during shape analysis, "
<< "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
<< "receives an Int-type tensor input.");
// If the output has type Byte and casting was requested, insert Integer cast
auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device, /*force_create_node=*/true);
seg_block.g()->appendNode(cast_node);
seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
}
@@ -254,11 +289,13 @@ void getSegmentsOutputByRunning(
std::vector<std::vector<int64_t>> input_shapes;
std::vector<at::ScalarType> input_types;
for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
auto current_input = seg_block.raw_inputs()[i];

if (ivalues_maps[current_input].isTensor()) {
// set the input_shape and data_type
// we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
// shape inference
auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
auto cur_ivalue = ivalues_maps[current_input];
at::ScalarType t = cur_ivalue.toTensor().scalar_type();

if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
@@ -271,10 +308,16 @@ void getSegmentsOutputByRunning(
cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
}

c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
if (dtype == c10::nullopt) {
TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
} else if (dtype && dtype.value() == nvinfer1::DataType::kINT8 && partitioning_info.cast_int8_inputs) {
// Special case to ensure input IValues to TensorRT engine are not Int8 type if the
// model itself is not quantized
cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
}

if (cur_ivalue.toTensor().sizes().size() == 0) {
// handle Scalar types, which have sizes of []
input_shapes.push_back(util::toVec(util::toDims(c10::List<int64_t>({1}))));
@@ -297,6 +340,7 @@ void runShapeAnalysis(
const ir::ShapeMode& shape_mode) {
// register every segment's input shape, and its running output IValues
for (auto& seg_block : ctx->partitioned_blocks[block]) {
LOG_GRAPH("Running shape analysis on block " << seg_block);
torch::jit::ConstantPooling(seg_block.g());
getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
}
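
A minimal standalone sketch (not part of the diff) of roughly what such an inserted cast looks like in TorchScript IR, assuming the five-argument aten::to.dtype overload; the node that createCastNode emits in the no-upstream-cast branch additionally pins the target device. The integer constant 3 is at::kInt and 4 is at::kLong in the ScalarType enum, which is what the pre-patch magic numbers insertConstant(3) / insertConstant(4) encoded.

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

#include <iostream>
#include <memory>
#include <string>

int main() {
  // Illustrative only: a cast of %x to Int (ScalarType 3), comparable to the
  // boundary cast inserted before a TensorRT segment.
  const std::string ir = R"IR(
    graph(%x : Tensor):
      %dtype : int = prim::Constant[value=3]()
      %false : bool = prim::Constant[value=0]()
      %none : NoneType = prim::Constant()
      %out : Tensor = aten::to(%x, %dtype, %false, %false, %none)
      return (%out))IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, g.get());
  std::cout << *g << std::endl;
  return 0;
}
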
1 change: 1 addition & 0 deletions core/util/trt_util.cpp
@@ -252,6 +252,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
{at::kHalf, nvinfer1::DataType::kHALF},
{at::kInt, nvinfer1::DataType::kINT32},
{at::kChar, nvinfer1::DataType::kINT8},
{at::kByte, nvinfer1::DataType::kINT8},
{at::kBool, nvinfer1::DataType::kBOOL}};
return at_trt_type_map;
}
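
With the new at::kByte entry, uint8 tensors now map to TensorRT's kINT8, which is exactly why the casts above are needed for non-quantized graphs: an Int8 engine input would otherwise be interpreted as quantized data. A short libtorch sketch (illustrative, not from the PR) of the runtime effect of the inserted boundary cast:

#include <torch/torch.h>

#include <iostream>

int main() {
  // A Byte (uint8) tensor such as a Torch segment might hand to a TensorRT
  // segment, e.g. the output of aten::ones with dtype 0.
  auto byte_tensor = torch::ones({5, 5}, torch::kByte);

  // The inserted aten::to node performs the equivalent of this cast before the
  // value crosses the segment boundary, so the engine sees INT32, not INT8.
  auto int_tensor = byte_tensor.to(torch::kInt);

  std::cout << byte_tensor.dtype() << " -> " << int_tensor.dtype() << std::endl;
  return 0;
}
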
3 changes: 3 additions & 0 deletions cpp/src/compile_spec.cpp
@@ -167,8 +167,11 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.convert_info.engine_settings.dla_local_dram_size = external.dla_local_dram_size;
internal.convert_info.engine_settings.dla_global_dram_size = external.dla_global_dram_size;

internal.partitioning_info.cast_int8_inputs = true;

if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
internal.convert_info.engine_settings.enabled_precisions.end()) {
internal.partitioning_info.cast_int8_inputs = false;
if (external.ptq_calibrator) {
internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
} else {
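
Condensed, the control flow of this hunk (a sketch, not additional code in the PR): int8 input casting is on by default and is switched off whenever INT8 precision is enabled, since a quantized model legitimately carries Int8 tensors. The Precision enum below is a hypothetical stand-in for nvinfer1::DataType, used only to keep the sketch self-contained.

#include <cassert>
#include <set>

// Hypothetical stand-in for nvinfer1::DataType, for illustration only.
enum class Precision { kFloat, kHalf, kInt8 };

// Restatement of the logic above: casting stays enabled only when INT8
// precision has not been requested. (The Python frontend below additionally
// disables it when a PTQ calibrator is supplied.)
bool should_cast_int8_inputs(const std::set<Precision>& enabled_precisions) {
  return enabled_precisions.find(Precision::kInt8) == enabled_precisions.end();
}

int main() {
  assert(should_cast_int8_inputs({Precision::kFloat}));  // non-quantized: cast
  assert(!should_cast_int8_inputs({Precision::kInt8}));  // quantized: keep Int8
  return 0;
}
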
17 changes: 17 additions & 0 deletions py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -300,11 +300,15 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
}

info.partitioning_info.cast_int8_inputs = true;

if (ptq_calibrator) {
info.convert_info.engine_settings.calibrator = ptq_calibrator;
info.partitioning_info.cast_int8_inputs = false;
} else {
if (info.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
info.convert_info.engine_settings.enabled_precisions.end()) {
info.partitioning_info.cast_int8_inputs = false;
info.lower_info.unfreeze_module = true;
info.lower_info.disable_cse = true;
}
@@ -313,10 +317,23 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
info.convert_info.engine_settings.refit = refit;
info.convert_info.engine_settings.debug = debug;

// Specify + replicate device settings for phases requiring it
info.convert_info.engine_settings.device.device_type = toTRTDeviceType(device.device_type);
info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
info.convert_info.engine_settings.device.dla_core = device.dla_core;
info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;

info.lower_info.target_device.device_type = toTRTDeviceType(device.device_type);
info.lower_info.target_device.gpu_id = device.gpu_id;
info.lower_info.target_device.dla_core = device.dla_core;
info.lower_info.target_device.allow_gpu_fallback = device.allow_gpu_fallback;

info.partitioning_info.target_device.device_type = toTRTDeviceType(device.device_type);
info.partitioning_info.target_device.gpu_id = device.gpu_id;
info.partitioning_info.target_device.dla_core = device.dla_core;
info.partitioning_info.target_device.allow_gpu_fallback = device.allow_gpu_fallback;

info.partitioning_info.enabled = torch_fallback.enabled;
info.partitioning_info.min_block_size = torch_fallback.min_block_size;
info.partitioning_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
60 changes: 60 additions & 0 deletions tests/core/partitioning/test_type_auto_conversion.cpp
@@ -107,3 +107,63 @@ TEST(Partitioning, ImplicitAutoConversionCorrectly) {
}
ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
}

TEST(Partitioning, ExplicitNodeAutoInt8ConversionCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor,
%y.1 : Tensor):

%26 : int = prim::Constant[value=1]()
%21 : bool = prim::Constant[value=0]()
%60 : Device = prim::Constant[value="cuda"]()
%14 : NoneType = prim::Constant()
%3 : int = prim::Constant[value=5]()
%19 : int = prim::Constant[value=0]()
%29 : int = prim::Constant[value=2]()
%13 : int[] = prim::ListConstruct(%3, %3)
%k_.1 : Tensor = aten::ones(%13, %19, %14, %60, %14)
%20 : int[] = prim::ListConstruct(%19)
%k.1 : Tensor = aten::sum(%k_.1, %20, %21, %14)
%x.5 : Tensor = aten::add_(%x.1, %y.1, %26)
%31 : Tensor = aten::mul(%y.1, %29)
%x.9 : Tensor = aten::add_(%x.5, %31, %26)
%x.13 : Tensor = aten::add_(%x.9, %k.1, %26)
%x.17 : Tensor = aten::sub_(%x.13, %k.1, %26)
%x.21 : Tensor = aten::add_(%x.17, %k.1, %26)
%x.25 : Tensor = aten::sub_(%x.21, %k.1, %26)

return (%x.25))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get(), true);

torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
partitioning_info.enabled = true;
partitioning_info.cast_int8_inputs = true;
partitioning_info.forced_fallback_operators = {"aten::ones"};
partitioning_info.truncate_long_and_double = true;
std::vector<torch_tensorrt::core::ir::Input> inputs;
inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));

std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
inputs_map.insert({g->inputs()[0], {inputs[0]}});
input_types.insert({g->inputs()[0], {{at::kFloat}}});
inputs_map.insert({g->inputs()[1], {inputs[1]}});
input_types.insert({g->inputs()[1], {{at::kInt}}});

partitioning_info.collection_input_spec_map = inputs_map;
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
ctx.input_types_map = input_types;
torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
torch_tensorrt::core::partitioning::partition(&ctx);
auto segmented_blocks = ctx.partitioned_blocks.begin()->second;

for (auto& seg_block : segmented_blocks) {
LOG_DEBUG(seg_block << " cur seg block");
}

// Seeking 1 inserted aten::to converting Byte to Int (%k_.1 is a Byte Tensor)
ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[0], 1));
}