feat(aten::Int): Lowers out aten::Int
This commit adds a pass to lower out aten::[Int/Float/Bool] / aten::NumToTensor pairs without exception. We assume this is safe, as similar passes exist in PyTorch for ONNX lowering; however, the scope of this rule is intentionally limited to avoid cases where the rewrite would not be safe. It should therefore not be expected that this change resolves all aten::Int issues, and the operator itself remains a limitation of TorchTRT. Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
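As context for the diff below, here is a minimal standalone sketch of what the pass does. The main() driver and the IR snippet are illustrative assumptions, not part of this commit; only the call to RemoveUnnecessaryCasts comes from the diff.

#include <memory>
#include <string>

#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/ir/irparser.h"

// Illustrative driver (assumption, not in this commit): build a graph that
// round-trips an int through a Tensor, then run the new lowering pass on it.
int main() {
  const std::string ir = R"IR(
    graph(%x: int):
      %t: Tensor = aten::NumToTensor(%x)
      %y: int = aten::Int(%t)
      return (%y))IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, g.get());
  // After the pass, the NumToTensor/Int pair collapses and the graph
  // should reduce to `return (%x)`.
  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(g);
  g->dump();
  return 0;
}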
1 parent ba9f730 · commit 908340f
Showing 6 changed files with 149 additions and 0 deletions.
@@ -0,0 +1,61 @@
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

#include <vector>

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

// Presumably this is safe since torch::jit::EraseNumberTypesOnBlock exists, which just
// removes prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes outright
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string int_cast_pattern = R"IR(
    graph(%1: int):
      %2: Tensor = aten::NumToTensor(%1)
      %3: int = aten::Int(%2)
      return (%3))IR";
  std::string int_clean_pattern = R"IR(
    graph(%1: int):
      return (%1))IR";

  std::string float_cast_pattern = R"IR(
    graph(%1: float):
      %2: Tensor = aten::NumToTensor(%1)
      %3: float = aten::Float(%2)
      return (%3))IR";
  std::string float_clean_pattern = R"IR(
    graph(%1: float):
      return (%1))IR";

  std::string bool_cast_pattern = R"IR(
    graph(%1: bool):
      %2: Tensor = aten::NumToTensor(%1)
      %3: bool = aten::Bool(%2)
      return (%3))IR";
  std::string bool_clean_pattern = R"IR(
    graph(%1: bool):
      return (%1))IR";

  torch::jit::SubgraphRewriter int_cast_rewriter;
  int_cast_rewriter.RegisterRewritePattern(int_cast_pattern, int_clean_pattern);
  int_cast_rewriter.runOnGraph(graph);

  torch::jit::SubgraphRewriter float_cast_rewriter;
  float_cast_rewriter.RegisterRewritePattern(float_cast_pattern, float_clean_pattern);
  float_cast_rewriter.runOnGraph(graph);

  torch::jit::SubgraphRewriter bool_cast_rewriter;
  bool_cast_rewriter.RegisterRewritePattern(bool_cast_pattern, bool_clean_pattern);
  bool_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After RemoveUnnecessaryCasts: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
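The commit touches six files but only two are reproduced on this page; the pass is presumably also registered in the lowering pipeline. A hedged sketch of what that wiring typically looks like follows (the signature is simplified and the exact pass order in core/lowering/lowering.cpp is an assumption; RemoveContiguous is an existing pass referenced by the tests below):

// Sketch (assumption): invoking the new pass alongside the existing
// lowering rewrites, as Torch-TensorRT's LowerGraph does.
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
  passes::RemoveContiguous(g);        // existing pass
  passes::RemoveUnnecessaryCasts(g);  // pass added by this commit
  // ... remaining lowering passes ...
}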
@@ -0,0 +1,79 @@
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/ir/subgraph_matcher.h"

TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
  std::string source_graph = R"IR(
    graph(%1: int):
      %2: Tensor = aten::NumToTensor(%1)
      %3: int = aten::Int(%2)
      %4: int = aten::add(%3, %3, %3)
      return (%4))IR";
  std::string target_graph = R"IR(
    graph(%1: int):
      %4: int = aten::add(%1, %1, %1)
      return (%4))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  // Run the pass under test on the source graph
  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveUnnecessaryCastFloatCorrectly) {
  std::string source_graph = R"IR(
    graph(%1: float):
      %2: Tensor = aten::NumToTensor(%1)
      %3: float = aten::Float(%2)
      %4: float = aten::add(%3, %3, %3)
      return (%4))IR";
  std::string target_graph = R"IR(
    graph(%1: float):
      %4: float = aten::add(%1, %1, %1)
      return (%4))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  // Run the pass under test on the source graph
  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveUnnecessaryCastBoolCorrectly) {
  std::string source_graph = R"IR(
    graph(%1: bool):
      %2: Tensor = aten::NumToTensor(%1)
      %3: bool = aten::Bool(%2)
      %4: bool = aten::__and__(%3, %3)
      return (%4))IR";
  std::string target_graph = R"IR(
    graph(%1: bool):
      %4: bool = aten::__and__(%1, %1)
      return (%4))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  // Run the pass under test on the source graph
  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}
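Each test passes the cleaned target graph to torch::jit::findPatternMatches as the pattern and searches for it in the rewritten source graph; a non-empty match set therefore confirms that the aten::NumToTensor/cast pair was lowered out while the downstream arithmetic node was preserved.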