diff --git a/core/conversion/converters/impl/stack.cpp b/core/conversion/converters/impl/stack.cpp index 2a4241ecf3..8d8a313cad 100644 --- a/core/conversion/converters/impl/stack.cpp +++ b/core/conversion/converters/impl/stack.cpp @@ -19,12 +19,12 @@ auto stack_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patt [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto in = args[0].IValue()->toListRef(); auto dim = args[1].unwrapToInt(); - if (-1 == dim) { + if (dim < 0) { auto first_in = in[0]; if (first_in.isTensor()) { - dim = first_in.toTensor().ndimension(); + dim = first_in.toTensor().ndimension() + dim + 1; } else { - dim = first_in.toCustomClass()->tensor()->getDimensions().nbDims; + dim = first_in.toCustomClass()->tensor()->getDimensions().nbDims + dim + 1; } } diff --git a/tests/core/conversion/converters/test_stack.cpp b/tests/core/conversion/converters/test_stack.cpp index 6d0fdbec44..4760031bea 100644 --- a/tests/core/conversion/converters/test_stack.cpp +++ b/tests/core/conversion/converters/test_stack.cpp @@ -18,8 +18,7 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) { params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( - jit_results[0], trt_results[0].reshape_as(jit_results[0]), THRESHOLD_E5)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], THRESHOLD_E5)); }; const auto graph = R"IR( graph(%0 : Tensor, @@ -35,9 +34,17 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) { %3 : int = prim::Constant[value=-1]() %4 : Tensor = aten::stack(%2, %3) return (%4))IR"; + const auto graph3 = R"IR( + graph(%0 : Tensor, + %1 : Tensor): + %2 : Tensor[] = prim::ListConstruct(%0, %1) + %3 : int = prim::Constant[value=-2]() + %4 : Tensor = aten::stack(%2, %3) + return (%4))IR"; TestATenStackPureTensorConvertsCorrectly(graph); TestATenStackPureTensorConvertsCorrectly(graph2); + TestATenStackPureTensorConvertsCorrectly(graph3); } TEST(Converters, ATenStackPureTensorDynamicConvertsCorrectly) { @@ -89,8 +96,7 @@ TEST(Converters, ATenStackDiffTensorConvertsCorrectly) { params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2}); auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1}); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( - jit_results[0], trt_results[0].reshape_as(jit_results[0]), THRESHOLD_E5)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], THRESHOLD_E5)); }; const auto graph = R"IR( graph(%0 : Tensor, @@ -106,6 +112,14 @@ TEST(Converters, ATenStackDiffTensorConvertsCorrectly) { %3 : int = prim::Constant[value=-1]() %4 : Tensor = aten::stack(%2, %3) return (%4))IR"; + const auto graph3 = R"IR( + graph(%0 : Tensor, + %1 : Float(4, 4, 4, strides=[16, 4, 1])): + %2 : Tensor[] = prim::ListConstruct(%0, %1) + %3 : int = prim::Constant[value=-3]() + %4 : Tensor = aten::stack(%2, %3) + return (%4))IR"; TestATenStackDiffTensorConvertsCorrectly(graph); TestATenStackDiffTensorConvertsCorrectly(graph2); + TestATenStackDiffTensorConvertsCorrectly(graph3); }