From 8bc4369a309c728d64b74cbb66359ebfa35b82c9 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 23 Jul 2020 14:20:46 -0700 Subject: [PATCH] feat(aten::prelu): Basic prelu support Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- .../conversion/converters/impl/activation.cpp | 21 +++++++++ tests/core/converters/test_activation.cpp | 47 +++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/core/conversion/converters/impl/activation.cpp b/core/conversion/converters/impl/activation.cpp index e77b3e7a50..d01cec811c 100644 --- a/core/conversion/converters/impl/activation.cpp +++ b/core/conversion/converters/impl/activation.cpp @@ -79,6 +79,27 @@ auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns() new_layer->setName(util::node_info(n).c_str()); auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + } + }).pattern({ + "aten::prelu(Tensor self, Tensor weight) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto slopes = args[1].unwrapToTensor(); + + //if (slopes.numel() != 1) { + // auto in_dims = util::toVec(in.getDimensions()); + // auto per_channel_shape = std::vector(in_dims.begin() + 2, in_dims.end()); + // for () + //} + + auto slope_tensor = tensor_to_const(ctx, slopes); + + auto new_layer = ctx->net->addParametricReLU(*in, *slope_tensor); + new_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); return true; } diff --git a/tests/core/converters/test_activation.cpp b/tests/core/converters/test_activation.cpp index 299420878c..128c3895b1 100644 --- a/tests/core/converters/test_activation.cpp +++ b/tests/core/converters/test_activation.cpp @@ -109,3 +109,50 @@ TEST(Converters, ATenHardTanhCustomRangeConvertsCorrectly) { ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } +TEST(Converters, ATenPReLUConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(1)): + %3 : Tensor = aten::prelu(%0, %1) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto in = at::randint(-5, 5, {5}, {at::kCUDA}); + auto slope = at::randint(-5, 5, {1}, {at::kCUDA}); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {slope}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + params = trtorch::core::conversion::get_named_params(g->inputs(), {slope}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenPReLUMultiChannelConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(10)): + %3 : Tensor = aten::prelu(%0, %1) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto in = at::randint(-5, 5, {1,10, 1, 1}, {at::kCUDA}); + auto slope = at::randint(-5, 5, {10}, {at::kCUDA}); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {slope}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + params = trtorch::core::conversion::get_named_params(g->inputs(), {slope}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +