1+ #include < string>
2+ #include " gtest/gtest.h"
3+ #include " torch/csrc/jit/ir/irparser.h"
4+ #include " tests/util/util.h"
5+ #include " core/compiler.h"
6+
7+ TEST (Converters, ATenCatPureTensorConvertsCorrectly) {
8+ const auto graph = R"IR(
9+ graph(%0 : Tensor,
10+ %1 : Tensor):
11+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
12+ %3 : int = prim::Constant[value=0]()
13+ %4 : Tensor = aten::cat(%2, %3)
14+ return (%4))IR" ;
15+
16+ auto g = std::make_shared<torch::jit::Graph>();
17+ torch::jit::parseIR (graph, &*g);
18+
19+ auto in1 = at::randint (1 , 10 , {5 }, {at::kCUDA });
20+ auto in2 = at::randint (1 , 10 , {5 }, {at::kCUDA });
21+
22+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
23+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1, in2});
24+
25+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
26+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1, in2});
27+
28+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
29+ }
30+
31+ TEST (Converters, ATenCatDiffTensorConvertsCorrectly) {
32+ const auto graph = R"IR(
33+ graph(%0 : Tensor,
34+ %1 : Float(5)):
35+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
36+ %3 : int = prim::Constant[value=0]()
37+ %4 : Tensor = aten::cat(%2, %3)
38+ return (%4))IR" ;
39+
40+ auto g = std::make_shared<torch::jit::Graph>();
41+ torch::jit::parseIR (graph, &*g);
42+
43+ auto in1 = at::randint (1 , 10 , {5 }, {at::kCUDA });
44+ auto in2 = at::randint (1 , 10 , {5 }, {at::kCUDA });
45+
46+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {in2});
47+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
48+
49+ params = trtorch::core::conversion::get_named_params (g->inputs (), {in2});
50+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
51+
52+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
53+ }
0 commit comments