From c066581d08b0a45f151ef7bdc69d3234fb926223 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Mon, 27 Jul 2020 13:50:07 -0700 Subject: [PATCH] feat(aten::prelu): Implement the multi-channel version of prelu and broadcasting checks Signed-off-byL Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/converters/converters.h | 5 +- .../conversion/converters/impl/activation.cpp | 30 +++++-- core/util/trt_util.cpp | 90 +++++++++++++------ core/util/trt_util.h | 1 + 4 files changed, 93 insertions(+), 33 deletions(-) diff --git a/core/conversion/converters/converters.h b/core/conversion/converters/converters.h index 5e368c4e69..18b1fc376d 100644 --- a/core/conversion/converters/converters.h +++ b/core/conversion/converters/converters.h @@ -55,7 +55,10 @@ struct Weights { inline nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) { auto t_weights = Weights(ctx, t); - return ctx->net->addConstant(t_weights.shape, t_weights.data)->getOutput(0); + auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data); + TRTORCH_CHECK(const_layer, "Unable to freeze tensor"); + const_layer->setName("[Freeze Tensor]"); + return const_layer->getOutput(0); } } // namespace converters diff --git a/core/conversion/converters/impl/activation.cpp b/core/conversion/converters/impl/activation.cpp index d01cec811c..31ac992414 100644 --- a/core/conversion/converters/impl/activation.cpp +++ b/core/conversion/converters/impl/activation.cpp @@ -88,18 +88,34 @@ auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns() auto in = args[0].ITensor(); auto slopes = args[1].unwrapToTensor(); - //if (slopes.numel() != 1) { - // auto in_dims = util::toVec(in.getDimensions()); - // auto per_channel_shape = std::vector(in_dims.begin() + 2, in_dims.end()); - // for () - //} + bool to_reshape = false; + auto original_shape = in->getDimensions(); + if (slopes.numel() != 1 && !util::broadcastable(in->getDimensions(), util::toDims(slopes.sizes()), /*multidirectional=*/false)) { + if (util::volume(in->getDimensions()) == util::volume(util::toDims(slopes.sizes()))) { + to_reshape = true; + LOG_DEBUG("Input shape is not broadcastable inserting shuffle layers to reshape to " << util::toDims(slopes.sizes())); + auto in_shuffle = ctx->net->addShuffle(*in); + TRTORCH_CHECK(in_shuffle, "Unable to create resize layer for aten::prelu input"); + in_shuffle->setReshapeDimensions(util::toDims(slopes.sizes())); + in_shuffle->setName(std::string("[Reshape in to " + util::toStr(util::toDims(slopes.sizes())) + " for broadcasting]").c_str()); + in = in_shuffle->getOutput(0); + } + } auto slope_tensor = tensor_to_const(ctx, slopes); - auto new_layer = ctx->net->addParametricReLU(*in, *slope_tensor); new_layer->setName(util::node_info(n).c_str()); - auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + auto out_tensor = new_layer->getOutput(0); + + if (to_reshape) { + auto out_shuffle = ctx->net->addShuffle(*out_tensor); + TRTORCH_CHECK(out_shuffle, "Unable to create resize layer for aten::prelu output"); + out_shuffle->setReshapeDimensions(original_shape); + out_shuffle->setName((std::string("[Reshape back to ") + util::toStr(original_shape) + std::string("]")).c_str()); + out_tensor = out_shuffle->getOutput(0); + } + out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor); LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); return true; } diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp index 79d76e2951..fbd12a8c53 100644 --- a/core/util/trt_util.cpp +++ b/core/util/trt_util.cpp @@ -6,6 +6,59 @@ namespace trtorch { namespace core { namespace util { +bool broadcastable(nvinfer1::Dims a, nvinfer1::Dims b, bool multidirectional) { + if (a == b) { + return true; + } + + if (multidirectional) { + nvinfer1::Dims a_dims_eq; + nvinfer1::Dims b_dims_eq; + if (a.nbDims > b.nbDims) { + a_dims_eq = a; + b_dims_eq = toDimsPad(toVec(b), a.nbDims); + } else if (a.nbDims < b.nbDims) { + a_dims_eq = toDimsPad(toVec(a), b.nbDims); + b_dims_eq = b; + } else { + a_dims_eq = a; + b_dims_eq = b; + } + + bool broadcastable = true; + for (int i = 0; i < a_dims_eq.nbDims; i++) { + if (b_dims_eq.d[i] == a_dims_eq.d[i] || (b_dims_eq.d[i] == 1 || a_dims_eq.d[i] == 1)) { + continue; + } else { + broadcastable = false; + break; + } + } + return broadcastable; + } else { + nvinfer1::Dims b_dims_eq; + if (a.nbDims > b.nbDims) { + b_dims_eq = toDimsPad(toVec(b), a.nbDims); + } else if (a.nbDims < b.nbDims) { + return false; + } else { + b_dims_eq = b; + } + + bool broadcastable = true; + for (int i = 0; i < a.nbDims; i++) { + if (b_dims_eq.d[i] == a.d[i] || b_dims_eq.d[i] == 1) { + continue; + } else { + broadcastable = false; + break; + } + } + return broadcastable; + } +} + + int64_t volume(const nvinfer1::Dims& d) { return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies()); } @@ -16,10 +69,7 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) { return toDims(l); } - if (pad_to > nvinfer1::Dims::MAX_DIMS) { - //TODO: Handle this with exceptions or whatever - LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT"); - } + TRTORCH_CHECK(pad_to <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT"); nvinfer1::Dims dims; dims.nbDims = pad_to; @@ -34,10 +84,8 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) { } nvinfer1::Dims toDims(c10::IntArrayRef l) { - if (l.size() > nvinfer1::Dims::MAX_DIMS) { - //TODO: Handle this with exceptions or whatever - LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT"); - } + TRTORCH_CHECK(l.size() <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT"); + nvinfer1::Dims dims; dims.nbDims = l.size(); for (size_t i = 0; i < l.size(); i++) { @@ -47,10 +95,8 @@ nvinfer1::Dims toDims(c10::IntArrayRef l) { } nvinfer1::Dims toDims(c10::List l) { - if (l.size() > nvinfer1::Dims::MAX_DIMS) { - //TODO: Handle this with exceptions or whatever - LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT"); - } + TRTORCH_CHECK(l.size() <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT"); + nvinfer1::Dims dims; dims.nbDims = l.size(); for (size_t i = 0; i < l.size(); i++) { @@ -65,10 +111,8 @@ nvinfer1::Dims toDimsPad(c10::List l, uint64_t pad_to) { return toDims(l); } - if (pad_to > nvinfer1::Dims::MAX_DIMS) { - //TODO: Handle this with exceptions or whatever - LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT"); - } + TRTORCH_CHECK(pad_to <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT"); + nvinfer1::Dims dims; dims.nbDims = pad_to; @@ -109,7 +153,7 @@ nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) { nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos) { // acceptable range for pos is [0, d.nbDims] TRTORCH_ASSERT(pos >= 0 && pos <= d.nbDims, "ERROR: Index to unsqueeze is out of bounds."); - + nvinfer1::Dims dims; int i = 0; @@ -148,10 +192,8 @@ std::string toStr(nvinfer1::Dims d) { nvinfer1::DimsHW toDimsHW(c10::List l) { - if (l.size() != 2) { - //TODO: Handle this with exceptions or whatever - LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::DimsHW is not 2"); - } + TRTORCH_CHECK(l.size() == 2, "The list requested to be converted to nvinfer1::DimsHW is not 2"); + nvinfer1::DimsHW dims; dims.nbDims = l.size(); for (size_t i = 0; i < l.size(); i++) { @@ -161,10 +203,8 @@ nvinfer1::DimsHW toDimsHW(c10::List l) { } nvinfer1::DimsHW toDimsHW(c10::IntArrayRef l) { - if (l.size() != 2) { - //TODO: Handle this with exceptions or whatever - LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::DimsHW is not 2"); - } + TRTORCH_CHECK(l.size() == 2, "The list requested to be converted to nvinfer1::DimsHW is not 2"); + nvinfer1::DimsHW dims; dims.nbDims = l.size(); for (size_t i = 0; i < l.size(); i++) { diff --git a/core/util/trt_util.h b/core/util/trt_util.h index 9593b7017f..693e309f46 100644 --- a/core/util/trt_util.h +++ b/core/util/trt_util.h @@ -77,6 +77,7 @@ namespace util { int64_t volume(const nvinfer1::Dims& d); +bool broadcastable(nvinfer1::Dims a, nvinfer1::Dims b, bool multidirectional=true); nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to); nvinfer1::Dims toDimsPad(c10::List l, uint64_t pad_to); nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);