Skip to content

Commit

Permalink
fix: Address review comments, fix failing tests due to bool mishandling
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
  • Loading branch information
peri044 committed Aug 10, 2021
1 parent 6844a7f commit 13eef91
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 19 deletions.
24 changes: 9 additions & 15 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ void LowerBlock(torch::jit::Block* b) {
DropUnusedNodes(b);
}

void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse) {
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
passes::UnpackHardSwish(g);
torch::jit::EliminateRedundantGuards(g);
torch::jit::RemoveListMutation(g);
Expand All @@ -42,7 +42,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse) {
passes::Conv3DToConvolution(g);
passes::FuseAddMMBranches(g);
passes::RemoveBNDimCheck(g);
if (!disable_cse) {
if (!lower_info.disable_cse) {
torch::jit::EliminateCommonSubexpression(g);
}
// torch::jit::UnrollLoops(g);
Expand Down Expand Up @@ -72,25 +72,19 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
auto g = lowered_mod.get_method(method_name).graph();
LOG_GRAPH(*g);

LOG_GRAPH("LibTorch Lowering");
auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());

// Go through TRTorch Lowering to reformat graph to be conversion friendly
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
// unfreeze_module is used to not perform constant folding on weights in the network.
// In quantization aware trained (QAT) models, weights are passed through quantize and
// dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
if (!lower_info.unfreeze_module) {
LOG_GRAPH("TRTorch Graph Lowering");
lowering::LowerGraph(g, false);
}
LOG_GRAPH("TRTorch Graph Lowering");
lowering::LowerGraph(graph_and_ivalues.first, lower_info);

LOG_GRAPH("LibTorch Lowering");
auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());

if (lower_info.unfreeze_module) {
LOG_GRAPH("TRTorch Graph Lowering");
lowering::LowerGraph(graph_and_ivalues.first, true);
}
// Is this necessary?
lowering::LowerBlock(g->block());
// lowering::LowerBlock(g->block());

return graph_and_ivalues;
}
Expand Down
10 changes: 8 additions & 2 deletions core/lowering/lowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ namespace lowering {

struct LowerInfo {
// Internal flag to ensure torch.jit.Module does not get frozen in lowering.cpp. This is required for QAT models.
bool unfreeze_module;
bool unfreeze_module = false;
// CommonSubexpressionElimination removes duplicate expressions which are used frequently in the graph.
// e.g., CSE replaces the identical stride nodes of multiple conv layers in a network with a single stride node.
// In QAT models, if two conv layers are consuming same input, there is a QDQ node for each input of the conv.
// Since these QDQ nodes will be identical, as they share the same input, one of them is eliminated by the CSE lowering
// pass. Disable this in order to not disturb TensorRT's QAT optimizations.
bool disable_cse = false;
};

void LowerBlock(torch::jit::Block* b);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse = false);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
const torch::jit::script::Module& mod,
Expand Down
1 change: 1 addition & 0 deletions cpp/api/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
} else {
internal.lower_info.unfreeze_module = true;
internal.lower_info.disable_cse = true;
internal.convert_info.engine_settings.calibrator = nullptr;
}
} else {
Expand Down
4 changes: 2 additions & 2 deletions py/trtorch/csrc/tensorrt_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
mod = core::lowering::LowerModule(mod);

auto spec = c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
core::CompileSpec cfg({});
lowering::LowerInfo lower_info;
for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
const auto& method_name = it->key();
auto method = mod.get_method(method_name);
auto graph = method.graph();
core::lowering::LowerGraph(graph, cfg.lower_info);
core::lowering::LowerGraph(graph, lower_info);
}

auto handles = c10::impl::GenericDict(
Expand Down
1 change: 1 addition & 0 deletions py/trtorch/csrc/tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
if (info.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
info.convert_info.engine_settings.enabled_precisions.end()) {
info.lower_info.unfreeze_module = true;
info.lower_info.disable_cse = true;
}
}
info.convert_info.engine_settings.sparse_weights = sparse_weights;
Expand Down

0 comments on commit 13eef91

Please sign in to comment.