7 changes: 7 additions & 0 deletions core/lowering/passes/remove_unnecessary_casts.cpp
@@ -131,6 +131,13 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
        user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
        user->destroy();
        break;
      case c10::aten::floor_divide:
        new_node = g->create(c10::aten::floordiv, user->inputs(), 1);
        new_node->insertAfter(user);
        new_node->outputs()[0]->setType(c10::IntType::get());
        user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
        user->destroy();
        break;
      default:
        new_node = g->create(user->kind(), user->inputs(), 1);
        new_node->insertAfter(user);
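The new case mirrors the neighboring arithmetic cases in this switch: when a single-use 0-D tensor flows into aten::floor_divide and the result is immediately cast back to a scalar, the tensor op is replaced with the scalar aten::floordiv. A minimal before/after IR sketch of the rewrite, assuming an int operand (value names are illustrative; the same pattern is exercised by the tests below):

    graph(%0: int):
      %1: Tensor = prim::Constant[value=[7]]()
      %2: Tensor = prim::NumToTensor(%0)
      %3: Tensor = aten::floor_divide(%1, %2)
      %4: int = aten::Int(%3)
      return (%4)

becomes

    graph(%0: int):
      %1: int = prim::Constant[value=7]()
      %2: int = aten::floordiv(%1, %0)
      return (%2)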
133 changes: 133 additions & 0 deletions tests/core/lowering/test_remove_unnecessary_casts.cpp
@@ -153,3 +153,136 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntCorrectly) {
  std::string source_graph = R"IR(
    graph(%0: int):
      %1: Tensor = prim::Constant[value=[7]]()
      %3: Tensor = prim::NumToTensor(%0)
      %4: Tensor = aten::floor_divide(%1, %3)
      %5: int = aten::Int(%4)
      return (%5))IR";
  std::string target_graph = R"IR(
    graph(%0: int):
      %1: int = prim::Constant[value=7]()
      %4: int = aten::floordiv(%1, %0)
      return (%4))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());

  // Swap the parsed constant for a true 0-D tensor constant; the IR parser
  // can only express the 1-D [7] literal above.
  auto first_op = *(sg->block()->nodes().begin());
  torch::jit::WithInsertPoint guard(first_op);
  torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(7), c10::nullopt, first_op->scope());
  r->copyMetadata(first_op->output());
  r->setType(c10::TensorType::get());
  first_op->output()->replaceAllUsesWith(r);
  first_op->destroy();

  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatCorrectly) {
  std::string source_graph = R"IR(
    graph(%0: float):
      %1: Tensor = prim::Constant[value=[8.]]()
      %3: Tensor = prim::NumToTensor(%0)
      %4: Tensor = aten::floor_divide(%1, %3)
      %5: float = aten::Float(%4)
      return (%5))IR";
  std::string target_graph = R"IR(
    graph(%0: float):
      %1: float = prim::Constant[value=8.]()
      %4: float = aten::floordiv(%1, %0)
      return (%4))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());

  // As in the int test, replace the parsed 1-D [8.] constant with a true
  // 0-D tensor constant.
  auto first_op = *(sg->block()->nodes().begin());
  torch::jit::WithInsertPoint guard(first_op);
  torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
  r->copyMetadata(first_op->output());
  r->setType(c10::TensorType::get());
  first_op->output()->replaceAllUsesWith(r);
  first_op->destroy();

  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntValuesAgree) {
  std::string source_graph_no_inputs = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %11: int = prim::Constant[value=7]()
      %3: Tensor = prim::NumToTensor(%0)
      %1: Tensor = prim::NumToTensor(%11)
      %4: Tensor = aten::floor_divide(%1, %3)
      %50: int = aten::Int(%4)
      %5: Tensor = prim::NumToTensor(%50)
      return (%5))IR";
  std::string target_graph_no_inputs = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %1: int = prim::Constant[value=7]()
      %40: int = aten::floordiv(%1, %0)
      %4: Tensor = prim::NumToTensor(%40)
      return (%4))IR";

  auto g_in = std::make_shared<torch::jit::Graph>();
  auto g_out = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
  torch::jit::parseIR(target_graph_no_inputs, g_out.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});

  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor()));
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
  std::string source_graph_no_inputs = R"IR(
    graph():
      %0: float = prim::Constant[value=2.]()
      %11: float = prim::Constant[value=7.]()
      %3: Tensor = prim::NumToTensor(%0)
      %1: Tensor = prim::NumToTensor(%11)
      %4: Tensor = aten::floor_divide(%1, %3)
      %50: float = aten::Float(%4)
      %5: Tensor = prim::NumToTensor(%50)
      return (%5))IR";
  std::string target_graph_no_inputs = R"IR(
    graph():
      %0: float = prim::Constant[value=2.]()
      %1: float = prim::Constant[value=7.]()
      %40: float = aten::floordiv(%1, %0)
      %4: Tensor = prim::NumToTensor(%40)
      return (%4))IR";

  auto g_in = std::make_shared<torch::jit::Graph>();
  auto g_out = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
  torch::jit::parseIR(target_graph_no_inputs, g_out.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}
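
As a standalone sanity check on the values the two agreement tests compare (not part of the PR; plain C++ mirroring the scalar semantics of aten::floordiv):

#include <cassert>
#include <cmath>

int main() {
  // Integer case from the test: floordiv(7, 2) == 3. C++ integer division
  // truncates toward zero, which coincides with flooring for the positive
  // operands used above.
  assert(7 / 2 == 3);
  // Float case from the test: floordiv(7., 2.) == 3.0, i.e. the floor of
  // the true quotient.
  assert(std::floor(7.0 / 2.0) == 3.0);
  return 0;
}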