fix: Add lowering pass to remove output repacking in convert_method_to_trt_engine calls #1945

Merged · 1 commit · May 31, 2023
3 changes: 3 additions & 0 deletions core/lowering/lowering.cpp
@@ -144,6 +144,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
passes::RemoveSingleUse0DTensors(g);
passes::RemoveUnnecessaryCasts(g);
passes::ReplaceAtenInt(g);
if (lower_info.converting_to_trt_engine) {
passes::RemoveCollectionCast(g);
}
passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
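
For context, a minimal sketch of the flag plumbing this gate relies on. It assumes the existing LowerGraph(graph, params, lower_info) entry point; the setup around the call is illustrative, not code from this PR:

// Hypothetical caller-side setup; only the converting_to_trt_engine field
// and the gating above are introduced by this PR.
auto graph = std::make_shared<torch::jit::Graph>(); // assume populated from a module
std::vector<torch::jit::IValue> params;             // assume gathered during lowering

torch_tensorrt::core::lowering::LowerInfo lower_info;
lower_info.converting_to_trt_engine = true; // set on the convert_method_to_trt_engine
                                            // path; compile() leaves this false

// With the flag set, LowerGraph additionally runs RemoveCollectionCast,
// flattening a prim::TupleConstruct/ListConstruct output into individual
// Tensor outputs before engine conversion.
torch_tensorrt::core::lowering::LowerGraph(graph, params, lower_info);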
4 changes: 4 additions & 0 deletions core/lowering/lowering.h
@@ -16,6 +16,10 @@ struct LowerInfo {
// Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
// pass. Disable this in order to not disturb TensorRT's QAT optimizations.
bool disable_cse = false;

// Whether the originating caller is `convert_method_to_trt_engine` (true) or `compile` (false)
bool converting_to_trt_engine = false;

ir::Device target_device;
std::vector<std::string> forced_fallback_modules;
friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -33,6 +33,7 @@ void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g);
void RemoveCollectionCast(std::shared_ptr<torch::jit::Graph>& g);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
39 changes: 39 additions & 0 deletions core/lowering/passes/remove_unnecessary_casts.cpp
@@ -356,6 +356,45 @@ void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g) {
LOG_GRAPH("Post removing aten.Int.Tensor operations: " << *g);
}

void RemoveCollectionCast(std::shared_ptr<torch::jit::Graph>& g) {
// Removes unnecessary collection-casting of graph outputs
// Only to be used if the overall output is intended to be a TRT Engine
// Will cause errors if used directly as a TorchScript graph

// Validate the output is a single value with type Tuple or List
if (!(g->outputs().size() == 1 &&
(g->outputs()[0]->node()->kind() == torch::jit::prim::TupleConstruct ||
g->outputs()[0]->node()->kind() == torch::jit::prim::ListConstruct))) {
return;
}

// Ensure all inputs to the Tuple/List Construct operator are regular Tensors
// (nested structures cannot be preserved in TensorRT)
auto all_tensors = true;
auto collection_inputs = g->outputs()[0]->node()->inputs();

for (size_t i = 0; i < collection_inputs.size(); ++i) {
all_tensors &= collection_inputs[i]->type()->isSubtypeOf(c10::TensorType::get());
}

if (!all_tensors) {
return;
}

// For each input to the collection packing operator, add its value directly
// as an output of the graph
for (size_t i = 0; i < collection_inputs.size(); ++i) {
g->registerOutput(collection_inputs[i]);
}

// Remove the original output value of the graph (the collection object)
g->eraseOutput(0);

// Clean up remnant collection node in graph
torch::jit::EliminateDeadCode(g);
LOG_GRAPH("Post removing collection casting operations: " << *g);
}

} // namespace passes
} // namespace lowering
} // namespace core
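
One case the all_tensors guard above intentionally rejects, sketched here for clarity (illustrative only; not part of this PR's test suite): when the packing node receives a nested, non-Tensor input, the pass leaves the graph untouched.

// Illustrative sketch: the outer TupleConstruct packs a nested tuple, so the
// all_tensors check fails and RemoveCollectionCast returns early.
std::string nested_graph = R"IR(
graph(%x.1 : Tensor):
%a.1 : Tensor = aten::relu(%x.1)
%inner : (Tensor, Tensor) = prim::TupleConstruct(%a.1, %x.1)
%outer : (Tensor, (Tensor, Tensor)) = prim::TupleConstruct(%a.1, %inner)
return (%outer))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(nested_graph, g.get());
torch_tensorrt::core::lowering::passes::RemoveCollectionCast(g);
// The graph still returns its single packed output; nothing was flattened.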
4 changes: 3 additions & 1 deletion cpp/src/compile_spec.cpp
@@ -78,9 +78,11 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
}
}

torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external, bool converting_to_trt_engine) {
torchtrt::core::CompileSpec internal = init_compile_spec(external);

internal.lower_info.converting_to_trt_engine = converting_to_trt_engine;

for (auto p : external.enabled_precisions) {
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
}
5 changes: 3 additions & 2 deletions cpp/src/torch_tensorrt.cpp
@@ -10,7 +10,7 @@ namespace torch_tensorrt {
torch_tensorrt::core::runtime::RTDevice to_internal_rt_device(Device device);
namespace torchscript {
// Defined in compile_spec.cpp
torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external);
torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external, bool converting_to_trt_engine = false);

bool check_method_operator_support(const torch::jit::script::Module& module, std::string method_name) {
return torch_tensorrt::core::CheckMethodOperatorSupport(module, method_name);
@@ -23,7 +23,8 @@ std::string convert_method_to_trt_engine(
LOG_DEBUG(get_build_info());
// Want to export a much simpler (non TRT header dependent) API so doing the
// type conversion here
return torch_tensorrt::core::ConvertGraphToTRTEngine(module, method_name, to_internal_compile_spec(info));
return torch_tensorrt::core::ConvertGraphToTRTEngine(
module, method_name, to_internal_compile_spec(info, /*bool converting_to_trt_engine=*/true));
}

torch::jit::script::Module compile(const torch::jit::script::Module& module, CompileSpec info) {
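
For reference, a hedged sketch of the user-facing effect of this change (mod, the input shape, and the surrounding setup are placeholders, not code from this PR):

// convert_method_to_trt_engine now lowers with converting_to_trt_engine=true,
// so a forward() returning a tuple or list of Tensors yields an engine with
// individual tensor outputs rather than a packed collection.
auto spec = torch_tensorrt::torchscript::CompileSpec({torch_tensorrt::Input({1, 3, 224, 224})});
std::string engine =
    torch_tensorrt::torchscript::convert_method_to_trt_engine(mod, "forward", spec);
// compile(mod, spec) is unaffected: it lowers with the flag false, preserving
// collection outputs in the returned TorchScript module.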
4 changes: 3 additions & 1 deletion py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -326,9 +326,11 @@ core::CompileSpec init_compile_spec(CompileSpec external) {
}
}

core::CompileSpec CompileSpec::toInternalCompileSpec() {
core::CompileSpec CompileSpec::toInternalCompileSpec(bool converting_to_trt_engine) {
core::CompileSpec info = init_compile_spec(*this);

info.lower_info.converting_to_trt_engine = converting_to_trt_engine;

for (auto p : enabled_precisions) {
info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
}
2 changes: 1 addition & 1 deletion py/torch_tensorrt/csrc/tensorrt_classes.h
@@ -123,7 +123,7 @@ std::string to_str(EngineCapability value);
nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);

struct CompileSpec : torch::CustomClassHolder {
core::CompileSpec toInternalCompileSpec();
core::CompileSpec toInternalCompileSpec(bool converting_to_trt_engine = false);
std::string stringify();
void appendInput(const c10::intrusive_ptr<Input>& ir) {
inputs.push_back(*ir);
3 changes: 2 additions & 1 deletion py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
@@ -158,7 +158,8 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec& info

py::bytes ConvertGraphToTRTEngine(const torch::jit::Module& mod, const std::string& method_name, CompileSpec& info) {
py::gil_scoped_acquire gil;
auto trt_engine = core::ConvertGraphToTRTEngine(mod, method_name, info.toInternalCompileSpec());
auto trt_engine = core::ConvertGraphToTRTEngine(
mod, method_name, info.toInternalCompileSpec(/*bool converting_to_trt_engine=*/true));
return py::bytes(trt_engine);
}

86 changes: 86 additions & 0 deletions tests/core/lowering/test_remove_unnecessary_casts.cpp
@@ -589,3 +589,89 @@ TEST(LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
// Validate identical graphs after pooling constants and canonicalizing
ASSERT_TRUE((tg->toString() == sg->toString()));
}

TEST(LoweringPasses, RemoveCollectionCastTuple) {
// Ensure the lowering pass transforms the first graph into the second
std::string source_graph = R"IR(
graph(%x.1 : Tensor):
%3 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%a.1 : Tensor = aten::mul(%x.1, %2)
%b.1 : Tensor = aten::add(%a.1, %2, %3)
%c.1 : Tensor = aten::relu(%b.1)
%d.1 : Tensor = aten::sqrt(%c.1)
%8 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%c.1, %d.1, %b.1)
return (%8))IR";

std::string target_graph = R"IR(
graph(%x.1 : Tensor):
%3 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%a.1 : Tensor = aten::mul(%x.1, %2)
%b.1 : Tensor = aten::add(%a.1, %2, %3)
%c.1 : Tensor = aten::relu(%b.1)
%d.1 : Tensor = aten::sqrt(%c.1)
return (%c.1, %d.1, %b.1))IR";

// Ensure the lowering pass transforms the first graph into the second
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());

torch_tensorrt::core::lowering::passes::RemoveCollectionCast(sg);
torch::jit::ConstantPooling(sg);
sg = torch::jit::Canonicalize(sg, false);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());

torch::jit::ConstantPooling(tg);
tg = torch::jit::Canonicalize(tg, false);

// Validate identical graphs after pooling constants and canonicalizing
ASSERT_TRUE((tg->toString() == sg->toString()));
}

TEST(LoweringPasses, RemoveCollectionCastList) {
// Ensure the lowering pass transforms the first graph into the second
std::string source_graph = R"IR(
graph(%x.1 : Tensor):
%3 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%a.1 : Tensor = aten::mul(%x.1, %2)
%b.1 : Tensor = aten::add(%a.1, %2, %3)
%c.1 : Tensor = aten::relu(%b.1)
%d.1 : Tensor = aten::sqrt(%c.1)
%8 : Tensor[] = prim::ListConstruct(%b.1, %c.1, %d.1)
return (%8))IR";

std::string target_graph = R"IR(
graph(%x.1 : Tensor):
%3 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%a.1 : Tensor = aten::mul(%x.1, %2)
%b.1 : Tensor = aten::add(%a.1, %2, %3)
%c.1 : Tensor = aten::relu(%b.1)
%d.1 : Tensor = aten::sqrt(%c.1)
return (%b.1, %c.1, %d.1))IR";

// Ensure the lowering pass transforms the first graph into the second
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());

torch_tensorrt::core::lowering::passes::RemoveCollectionCast(sg);
torch::jit::ConstantPooling(sg);
sg = torch::jit::Canonicalize(sg, false);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());

torch::jit::ConstantPooling(tg);
tg = torch::jit::Canonicalize(tg, false);

// Validate identical graphs after pooling constants and canonicalizing
ASSERT_TRUE((tg->toString() == sg->toString()));
}