Skip to content

Commit

Permalink
feat(//core/lowering): Adding a new pass to handle new dim checks for
Browse files Browse the repository at this point in the history
batchnorm

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Jan 22, 2021
1 parent 6eeba1c commit 3d14cda
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
passes::Conv2DToConvolution(g);
passes::Conv3DToConvolution(g);
passes::FuseAddMMBranches(g);
passes::RemoveBNDimCheck(g);
torch::jit::EliminateCommonSubexpression(g);
// torch::jit::UnrollLoops(g);
torch::jit::EliminateCommonSubexpression(g);
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ cc_library(
"exception_elimination.cpp",
"fuse_addmm_branches.cpp",
"fuse_flatten_linear.cpp",
"remove_bn_dim_check.cpp",
"remove_contiguous.cpp",
"remove_dropout.cpp",
"remove_to.cpp",
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveTo(std::shared_ptr<torch::jit::Graph> graph);
Expand Down
88 changes: 88 additions & 0 deletions core/lowering/passes/remove_bn_dim_check.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/jit_log.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/guard_elimination.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/runtime/graph_executor.h"

#include "core/util/prelude.h"

#include <vector>

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {
namespace {
using namespace torch::jit;
struct BNDimCheckRemoval {
BNDimCheckRemoval(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}

void run() {
findBNDimCheckNodes(graph_->block());
torch::jit::EliminateDeadCode(graph_);
LOG_GRAPH("Post aten::addmm branch fusion: " << *graph_);
}

private:
bool isBNDimCheckNodes(Node* n) {
/// Check if this Node hosts a pattern like so:
/// %290 : bool = aten::ne(%289, %9)
/// = prim::If(%290)
/// block0():
/// %291 : str = aten::format(%10, %289)
/// = prim::RaiseException(%291)
/// -> ()
/// block1():
/// -> ()

if (n->blocks().size() != 2) {
return false;
}
auto arm1 = n->blocks()[0];
auto arm2 = n->blocks()[1];
if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
// Make sure that the node doesn't actually produce any Value that are
// used by other nodes
return false;
}

auto arm1_start = arm1->nodes().begin();

if ((*arm1_start)->kind() != c10::Symbol::fromQualString("aten::format") && (*(++arm1_start))->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) {
// Make sure that block0 is solely just the exception and the return
return false;
}

if ((*(arm2->nodes().begin()))->kind() != prim::Return) {
// Make sure that block1 is solely the return
return false;
}

return true;
}

void findBNDimCheckNodes(Block* b) {
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
auto n = *it;
if (n->kind() == prim::If && isBNDimCheckNodes(n)) {
LOG_GRAPH("Found that node " << *n << " is an batch norm dim check node (EliminateChecks)" << std::endl);
it.destroyCurrent();
}
}
}

std::shared_ptr<Graph> graph_;
};
} // namespace

void RemoveBNDimCheck(std::shared_ptr<Graph> graph) {
BNDimCheckRemoval bndcr(std::move(graph));
bndcr.run();
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch

0 comments on commit 3d14cda

Please sign in to comment.