diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 5ce3d02978..838175461e 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -223,13 +223,20 @@ auto aten_registrations TORCHTRT_UNUSED = {c10::Symbol::fromQualString("aten::slice"), [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { c10::List list = args.at(n->input(0)).IValue()->to>(); - int64_t start = 0; + int64_t end = 9223372036854775807; auto startIVal = args.at(n->input(1)).IValue(); + auto endIVal = args.at(n->input(2)).IValue(); + if (!startIVal->isNone()) { start = args.at(n->input(1)).unwrapToInt(); } - int64_t end = args.at(n->input(2)).unwrapToInt(); + if (!endIVal->isNone()) { + end = args.at(n->input(2)).unwrapToInt(); + } + if (start > end) { + LOG_DEBUG("The end should be greater than start"); + } int64_t step = args.at(n->input(3)).unwrapToInt(); const int64_t list_size = list.size(); @@ -253,8 +260,9 @@ auto aten_registrations TORCHTRT_UNUSED = return sliced_list; }, - EvalOptions().validSchemas( - {"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})}) + EvalOptions().validSchemas({"aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> (t[])"})}) + // EvalOptions().validSchemas( + // {"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})}) .evaluator( {c10::Symbol::fromQualString("aten::len"), [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { @@ -896,8 +904,14 @@ auto aten_registrations TORCHTRT_UNUSED = auto step = args.at(n->input(2)).unwrapToInt(); return start + idx * step; }, - EvalOptions().validSchemas({"aten::__derive_index(int idx, int start, int step) -> int"})}); - + EvalOptions().validSchemas({"aten::__derive_index(int idx, int start, int step) -> int"})}) + .evaluator( + {c10::Symbol::fromQualString("aten::list"), + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { + c10::List list = args.at(n->input(0)).IValue()->to>(); + return list.copy(); + }, + EvalOptions().validSchemas({"aten::list.t(t[] l) -> (t[])"})}); } // namespace } // namespace evaluators } // namespace conversion diff --git a/tests/core/conversion/evaluators/evaluator_test.bzl b/tests/core/conversion/evaluators/evaluator_test.bzl index cc3a448ad9..04ef178786 100644 --- a/tests/core/conversion/evaluators/evaluator_test.bzl +++ b/tests/core/conversion/evaluators/evaluator_test.bzl @@ -22,5 +22,5 @@ def evaluator_test(name, visibility = None): ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], "//conditions:default": ["@libtorch//:libtorch"], }), - timeout = "short", + timeout = "long", ) diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp index 16936ddd6a..c21f7b5461 100644 --- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp +++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp @@ -931,3 +931,38 @@ TEST(Evaluators, IsNotTrueEvaluatesCorrectly) { ASSERT_TRUE(jit_results[0] == trt_results[0]); } + +TEST(Evaluators, IsAtenSliceEvaluateCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int[] = prim::Constant[value= [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]() + %2 : int = prim::Constant[value = 0]() + %3 : int = prim::Constant[value = 7]() + %4 : int = prim::Constant[value = 2]() + %5 : int[] = aten::slice(%1, %2, %3, %4) + return (%5))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, IsAtenListEvaluateCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int[] = prim::Constant[value= [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]() + %2 : int[] = aten::list(%1) + return (%2))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +}