Skip to content

Commit b57a6dd

Browse files
committed
feat: Add aten::type_as lowering pass
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent 1f2ffc4 commit b57a6dd

File tree

3 files changed

+73
-3
lines changed

3 files changed

+73
-3
lines changed

core/lowering/passes/reduce_to.cpp

+21-3
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,34 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
1212
graph(%x, %device, %dtype, %nb, %copy, %format):
1313
%out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
1414
return (%out))IR";
15-
std::string to_general_pattern = R"IR(
15+
std::string to_dtype_pattern = R"IR(
1616
graph(%x, %device, %dtype, %nb, %copy, %format):
1717
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
1818
return (%out))IR";
1919

20+
std::string to_type_as_pattern = R"IR(
21+
graph(%input, %other):
22+
%out : Tensor = aten::type_as(%input, %other)
23+
return (%out))IR";
24+
25+
std::string to_other_pattern = R"IR(
26+
graph(%input, %other):
27+
%5 : bool = prim::Constant[value=0]()
28+
%6 : None = prim::Constant()
29+
%out : Tensor = aten::to(%input, %other, %5, %5, %6)
30+
return (%out))IR";
31+
2032
// replace aten::to.device with aten::to.dtype
2133
torch::jit::SubgraphRewriter map_aten_device_to_dtype;
22-
map_aten_device_to_dtype.RegisterRewritePattern(to_device_pattern, to_general_pattern);
34+
map_aten_device_to_dtype.RegisterRewritePattern(to_device_pattern, to_dtype_pattern);
2335
map_aten_device_to_dtype.runOnGraph(graph);
24-
LOG_GRAPH("Post lowering of aten::to.device -> " << *graph);
36+
37+
// replace aten::type_as with aten::to.other
38+
torch::jit::SubgraphRewriter map_aten_type_as_to_other;
39+
map_aten_type_as_to_other.RegisterRewritePattern(to_type_as_pattern, to_other_pattern);
40+
map_aten_type_as_to_other.runOnGraph(graph);
41+
42+
LOG_GRAPH("Post lowering of [aten::to.device|aten::type_as] -> " << *graph);
2543
}
2644

2745
} // namespace passes

tests/core/conversion/converters/test_cast.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <torch/torch.h>
22
#include <string>
33
#include "core/compiler.h"
4+
#include "core/lowering/passes/passes.h"
45
#include "gtest/gtest.h"
56
#include "tests/util/util.h"
67
#include "torch/csrc/jit/ir/irparser.h"
@@ -133,3 +134,31 @@ TEST(Converters, ATenBoolToINT32TensorConvertsCorrectly) {
133134

134135
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
135136
}
137+
138+
TEST(Converters, ATenTypeAsConvertsCorrectly) {
139+
const auto graph = R"IR(
140+
graph(%0 : Tensor,
141+
%1 : Tensor):
142+
%2 : int = prim::Constant[value=-1]()
143+
%a : int = prim::Constant[value=1]()
144+
%4 : Tensor = aten::add(%0, %2, %a)
145+
%5 : Tensor = aten::gt(%1, %a)
146+
%6 : Tensor = aten::type_as(%4, %5)
147+
return (%6, %5))IR";
148+
149+
auto g = std::make_shared<torch::jit::Graph>();
150+
torch::jit::parseIR(graph, &*g);
151+
152+
auto in1 = at::randint(1, 3, {3, 4, 3}, {at::kCUDA});
153+
auto in2 = at::randint(1, 3, {3, 4, 3}, {at::kCUDA});
154+
// Lower aten::type_as to aten::to.other
155+
trtorch::core::lowering::passes::ReduceToOperation(g);
156+
157+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
158+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});
159+
160+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
161+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});
162+
163+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
164+
}

tests/core/lowering/test_reduce_to_pass.cpp

+23
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,26 @@ TEST(LoweringPasses, ReduceToCorrectly) {
2626

2727
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
2828
}
29+
30+
TEST(LoweringPasses, ReduceAtenTypeAsCorrectly) {
31+
std::string source_graph = R"IR(
32+
graph(%input, %other):
33+
%out : Tensor = aten::type_as(%input, %other)
34+
return (%out))IR";
35+
std::string target_graph = R"IR(
36+
graph(%input, %other):
37+
%5 : bool = prim::Constant[value=0]()
38+
%6 : None = prim::Constant()
39+
%out : Tensor = aten::to(%input, %other, %5, %5, %6)
40+
return (%out))IR";
41+
42+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
43+
auto sg = std::make_shared<torch::jit::Graph>();
44+
torch::jit::parseIR(source_graph, &*sg);
45+
trtorch::core::lowering::passes::ReduceToOperation(sg);
46+
47+
auto tg = std::make_shared<torch::jit::Graph>();
48+
torch::jit::parseIR(target_graph, &*tg);
49+
50+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
51+
}

0 commit comments

Comments
 (0)