From 02d502cde50d4c16414daf12547cfe2795b30f83 Mon Sep 17 00:00:00 2001
From: Michael Feliz
Date: Thu, 2 Mar 2023 11:59:19 -0800
Subject: [PATCH] Add dynamic conversion path to aten::mul evaluator

---
 core/conversion/evaluators/aten.cpp | 10 +++++++++
 core/conversion/var/Var.cpp         |  7 +++++--
 tests/cpp/test_dynamic_size.cpp     | 32 +++++++++++++++++++++++++++++
 3 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index d78bf9878c..b72320b8da 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -455,6 +455,16 @@ auto aten_registrations TORCHTRT_UNUSED =
         .evaluator(
             {c10::Symbol::fromQualString("aten::mul"),
              [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               if (!constTypesOnly(args)) {
+                 auto a = args.at(n->input(0)).ITensorOrFreeze(ctx);
+                 auto b = args.at(n->input(1)).ITensorOrFreeze(ctx);
+                 auto mul =
+                     converters::add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, a, b, util::node_info(n));
+                 TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
+                 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));
+                 LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                 return {};
+               }
                if (args.at(n->input(0)).IValue()->isInt()) {
                  auto a = args.at(n->input(0)).unwrapToInt();
                  auto b = args.at(n->input(1)).unwrapToInt();
diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp
index df8dc2e544..0444663830 100644
--- a/core/conversion/var/Var.cpp
+++ b/core/conversion/var/Var.cpp
@@ -92,8 +92,9 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
   }
 
   TORCHTRT_CHECK(
-      isITensor() || (isIValue() && (ptr_.ivalue->isTensor() || ptr_.ivalue->isCustomClass())),
-      "Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());
+      isITensor() ||
+          (isIValue() && (ptr_.ivalue->isTensor() || ptr_.ivalue->isScalar() || ptr_.ivalue->isCustomClass())),
+      "Requested either IValue containing a Tensor, Scalar or ITensor, however Var type is " << type_name());
 
   nvinfer1::ITensor* out;
 
@@ -101,6 +102,8 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
     if (ptr_.ivalue->isTensor()) {
       auto tensor = ptr_.ivalue->toTensor();
      out = converters::tensor_to_const(ctx, tensor);
+    } else if (ptr_.ivalue->isScalar()) {
+      out = converters::scalar_to_tensor(ctx, ptr_.ivalue->toScalar());
     } else {
       // Split converter generates c10::IValue which hold TensorContainer.
       auto output_container = ptr_.ivalue->toCustomClass<TensorContainer>();
diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp
index 202b4f5ddc..afec847b3e 100644
--- a/tests/cpp/test_dynamic_size.cpp
+++ b/tests/cpp/test_dynamic_size.cpp
@@ -87,5 +87,37 @@ TEST(Converters, ATenResizeGetItemDynShapeCorrectly) {
 
   auto trt = trt_results[0].reshape(jit_results[0].sizes());
 
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenResizeGetItemDynShapeMulCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=0]()
+        %3 : int = prim::Constant[value=-1]()
+        %4 : int = prim::Constant[value=2]()
+        %size.1 : int[] = aten::size(%x.1)
+        %37 : int = aten::__getitem__(%size.1, %2)
+        %38 : int = aten::mul(%37, %4)
+        %39 : int[] = prim::ListConstruct(%38, %3)
+        %7 : Tensor = aten::reshape(%x.1, %39)
+        return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {16, 16, 16}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
\ No newline at end of file
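
For context, a minimal usage sketch (not part of the patch) of how this path is
reached through the public C++ API. Assumptions: the module file name
"flatten_mul.ts" and all shapes are illustrative, and the scripted forward() is
taken to compute x.reshape(x.size(0) * 2, -1), which lowers to the same
aten::size / aten::__getitem__ / aten::mul pattern as the new test IR. With a
dynamic batch dimension the size element is not a compile-time constant, so the
evaluator now emits an elementwise kPROD layer instead of failing.

  #include <iostream>
  #include <torch/script.h>
  #include "torch_tensorrt/torch_tensorrt.h"

  int main() {
    // Hypothetical scripted module whose forward() is:
    //   return x.reshape(x.size(0) * 2, -1);
    auto mod = torch::jit::load("flatten_mul.ts");
    mod.to(at::kCUDA);

    // Dynamic batch dimension (min/opt/max shapes), so aten::size(%x)[0] is
    // an ITensor at conversion time and aten::mul takes the new dynamic path.
    auto spec = torch_tensorrt::ts::CompileSpec({torch_tensorrt::Input(
        std::vector<int64_t>{1, 16, 16},
        std::vector<int64_t>{16, 16, 16},
        std::vector<int64_t>{32, 16, 16})});

    auto trt_mod = torch_tensorrt::ts::compile(mod, spec);

    auto in = at::randn({16, 16, 16}, {at::kCUDA});
    auto out = trt_mod.forward({in}).toTensor();
    std::cout << out.sizes() << std::endl; // [32, 128] for a 16x16x16 input
    return 0;
  }

With static input shapes, constTypesOnly(args) still holds and the evaluator
keeps folding the multiply to an int at conversion time; the kPROD layer is
emitted only when a non-constant operand reaches the evaluator.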