@@ -204,4 +204,87 @@ TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) {
204
204
auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
205
205
206
206
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
207
- }
207
+ }
208
+
209
+ TEST (Converters, ATenSplitSizesInScriptingConvertsCorrectly) {
210
+ const auto graph = R"IR(
211
+ graph(%x.1 : Tensor):
212
+ %2 : int[] = prim::Constant[value=[1, 2]]()
213
+ %3 : int = prim::Constant[value=1]()
214
+ %4 : Tensor[] = aten::split(%x.1, %2, %3)
215
+ %x1.1 : Tensor, %x2.1 : Tensor = prim::ListUnpack(%4)
216
+ return (%x1.1, %x2.1))IR" ;
217
+
218
+ auto g = std::make_shared<torch::jit::Graph>();
219
+
220
+ torch::jit::parseIR (graph, &*g);
221
+
222
+ auto in = at::randint (1 , 10 , {1 , 3 , 4 , 4 }, {at::kCUDA });
223
+
224
+ auto jit_in = at::clone (in);
225
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
226
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
227
+
228
+ auto trt_in = at::clone (in);
229
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
230
+
231
+ for (int i = 0 ; i < jit_results.size (); i++) {
232
+ auto trt = trt_results[i].reshape (jit_results[i].sizes ());
233
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[i], trt, 2e-6 ));
234
+ }
235
+ }
236
+
237
+ TEST (Converters, ATenSplitSizesinTracingConvertsCorrectly) {
238
+ const auto graph = R"IR(
239
+ graph(%argument_1.1 : Tensor):
240
+ %2 : int[] = prim::Constant[value=[1, 2]]()
241
+ %3 : int = prim::Constant[value=1]()
242
+ %4 : Tensor[] = aten::split_with_sizes(%argument_1.1, %2, %3)
243
+ %5 : Tensor, %6 : Tensor = prim::ListUnpack(%4)
244
+ return (%5, %6))IR" ;
245
+
246
+ auto g = std::make_shared<torch::jit::Graph>();
247
+
248
+ torch::jit::parseIR (graph, &*g);
249
+
250
+ auto in = at::randint (1 , 10 , {1 , 3 , 4 , 4 }, {at::kCUDA });
251
+
252
+ auto jit_in = at::clone (in);
253
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
254
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
255
+
256
+ auto trt_in = at::clone (in);
257
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
258
+
259
+ for (int i = 0 ; i < jit_results.size (); i++) {
260
+ auto trt = trt_results[i].reshape (jit_results[i].sizes ());
261
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[i], trt, 2e-6 ));
262
+ }
263
+ }
264
+
265
+ TEST (Converters, ATenSplitFixedConvertsCorrectly) {
266
+ const auto graph = R"IR(
267
+ graph(%argument_1.1 : Tensor):
268
+ %2 : int = prim::Constant[value=1]()
269
+ %3 : Tensor[] = aten::split(%argument_1.1, %2, %2)
270
+ %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3)
271
+ return (%4, %5, %6))IR" ;
272
+
273
+ auto g = std::make_shared<torch::jit::Graph>();
274
+
275
+ torch::jit::parseIR (graph, &*g);
276
+
277
+ auto in = at::randint (1 , 10 , {1 , 3 , 4 , 4 }, {at::kCUDA });
278
+
279
+ auto jit_in = at::clone (in);
280
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
281
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
282
+
283
+ auto trt_in = at::clone (in);
284
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
285
+
286
+ for (int i = 0 ; i < jit_results.size (); i++) {
287
+ auto trt = trt_results[i].reshape (jit_results[i].sizes ());
288
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[i], trt, 2e-6 ));
289
+ }
290
+ }
0 commit comments