1+ #include < string>
2+ #include " gtest/gtest.h"
3+ #include " torch/csrc/jit/ir/irparser.h"
4+ #include " tests/util/util.h"
5+ #include " core/compiler.h"
6+
7+ TEST (Converters, ATenSelectIntTwiceConvertsCorrectly) {
8+ const auto graph = R"IR(
9+ graph(%0 : Tensor):
10+ %2 : int = prim::Constant[value=0]()
11+ %3 : int = prim::Constant[value=3]()
12+ %4 : Tensor = aten::select(%0, %2, %2)
13+ %5 : Tensor = aten::select(%4, %2, %3)
14+ return (%5))IR" ;
15+
16+ auto g = std::make_shared<torch::jit::Graph>();
17+
18+ torch::jit::parseIR (graph, &*g);
19+
20+ auto in = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
21+
22+ auto jit_in = at::clone (in);
23+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
24+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
25+
26+ auto trt_in = at::clone (in);
27+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
28+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
29+
30+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
31+
32+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
33+ }
0 commit comments