diff --git a/core/conversion/converters/impl/constant_pad.cpp b/core/conversion/converters/impl/constant_pad.cpp
index 6d3f1ab609..4191cb1bab 100644
--- a/core/conversion/converters/impl/constant_pad.cpp
+++ b/core/conversion/converters/impl/constant_pad.cpp
@@ -16,127 +16,63 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
     {"aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)",
      [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
        auto in = args[0].ITensor();
-       auto inDims = in->getDimensions();
-       int64_t inRank = inDims.nbDims;
+       auto in_dims = in->getDimensions();
+       int64_t in_rank = in_dims.nbDims;
        auto padding = args[1].unwrapToIntList().vec();
-       int64_t padSize = padding.size();
+       int64_t pad_size = padding.size();
        auto value = args[2].unwrapToScalar().to<float>();
        at::Tensor value_tensor = torch::tensor(value, util::TRTDataTypeToScalarType(in->getType()));
-       auto valueTensor = tensor_to_const(ctx, value_tensor);
-       TORCHTRT_CHECK(padSize % 2 == 0, "Length of pad must be even but instead it equals " << padSize);
-
-       int64_t l_pad = padSize / 2;
-       TORCHTRT_CHECK(
-           inRank >= (int64_t)l_pad,
-           "Length of pad should be no more than twice the number of "
-           "dimensions of the input. Pad length is "
-               << padSize << "while the input has " << inRank << "dimensions.");
-
-       // TODO negative padding. When the pad is negative, we need to crop the image.
-
-       std::vector<nvinfer1::ITensor*> tensors_vec;
-       // input: (N, C, D_in, H_in, W_in).
-       // padding: (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
-       // When axis is inRank - 1, making W_out = W_in + padding_left + padding_right.
-       // When axis is inRank - 2, making H_out = H_in + padding_top + padding_bottom.
-       // When axis is inRank - 3, making D_out = D_in + padding_front + padding_back.
-       for (int64_t i = 0; i < l_pad; i++) {
-         int64_t axis = inRank - (i + 1); // axis = {inRank - 1, inRank - 2, inRank - 3}
-         int64_t padding_index = i * 2;
-
-         if (padding[padding_index] > 0) { // left/top/front padding value
-           tensors_vec.clear();
-           if (ctx->input_is_dynamic) {
-             at::Tensor left_indices = torch::tensor({0}, torch::kInt32);
-             auto indicesTensor = tensor_to_const(ctx, left_indices);
-             auto left_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
-             auto left_gather_out = left_gather_layer->getOutput(0);
-
-             // fill the left_gather_out with value
-             auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
-             auto shape_gather_out = ctx->net->addShape(*left_gather_out)->getOutput(0);
-             fill_layer->setInput(0, *shape_gather_out);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             for (int i = 0; i < padding[padding_index]; i++) {
-               tensors_vec.push_back(padTensor);
-             }
-           } else {
-             inDims.d[axis] = padding[padding_index];
-             auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             tensors_vec.push_back(padTensor);
-           }
+       auto value_itensor = tensor_to_const(ctx, value_tensor);
+       TORCHTRT_CHECK(pad_size % 2 == 0, "Length of pad must be even but instead it equals " << pad_size);
+
+       std::vector<int64_t> start(in_rank, 0);
+       std::vector<int64_t> total_padding(in_rank, 0);
+       std::vector<int64_t> stride(in_rank, 1);
+
+       // Padding is stored (left, right) starting from the last dim and working backwards
+       for (size_t i = 0UL; i < padding.size(); i += 2) {
+         auto left = padding[i];
+         TORCHTRT_CHECK(left >= 0, "Unsupported negative pad at index " << i);
+         auto right = padding[i + 1];
+         TORCHTRT_CHECK(right >= 0, "Unsupported negative pad at index " << i + 1);
+         auto idx = in_rank - ((i / 2) + 1);
+         start[idx] = -left;
+         total_padding[idx] = left + right;
+       }

-         tensors_vec.push_back(in);
-         auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-         concat_layer->setAxis(axis);
-         in = concat_layer->getOutput(0);
-         inDims = in->getDimensions();
+       auto size = stride; // placeholder for the dynamic case
+       if (!ctx->input_is_dynamic) {
+         size = total_padding;
+         for (size_t i = 0UL; i < total_padding.size(); ++i) {
+           size[i] += in_dims.d[i];
         }
+       }

-         if (padding[padding_index + 1] > 0) { // right/bottom/back padding value
-           tensors_vec.clear();
-           tensors_vec.push_back(in);
-
-           nvinfer1::ITensor* indicesTensor = NULL;
-           if (inDims.d[axis] == -1) {
-             auto shapeTensor = ctx->net->addShape(*in)->getOutput(0);
-             at::Tensor dimValue = torch::tensor({axis}, torch::kInt32);
-             auto dimTensor = tensor_to_const(ctx, dimValue);
-             indicesTensor = ctx->net->addGather(*shapeTensor, *dimTensor, 0)->getOutput(0);
-             auto oneTensor = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
-             indicesTensor = ctx->net->addElementWise(*indicesTensor, *oneTensor, nvinfer1::ElementWiseOperation::kSUB)
-                                 ->getOutput(0);
-           } else {
-             auto indices = torch::tensor({inDims.d[axis] - 1}, torch::kInt32);
-             indicesTensor = tensor_to_const(ctx, indices);
-           }
-           auto right_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
-           auto right_gather_out = right_gather_layer->getOutput(0);
-
-           if (ctx->input_is_dynamic) {
-             // fill the right_gather_out with value
-             auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
-             auto shape_gather_out = ctx->net->addShape(*right_gather_out)->getOutput(0);
-             fill_layer->setInput(0, *shape_gather_out);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             for (int i = 0; i < padding[padding_index + 1]; i++) {
-               tensors_vec.push_back(padTensor);
-             }
-           } else {
-             inDims.d[axis] = padding[padding_index + 1];
-             auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             tensors_vec.push_back(padTensor);
-           }
-           auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-           concat_layer->setAxis(axis);
-           in = concat_layer->getOutput(0);
-           inDims = in->getDimensions();
-         }
+       auto slice_layer = ctx->net->addSlice(
+           *in,
+           util::toDims(c10::IntArrayRef(start)),
+           util::toDims(c10::IntArrayRef(size)),
+           util::toDims(c10::IntArrayRef(stride)));
+       TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
+       slice_layer->setName((util::node_info(n) + "_slice").c_str());
+       slice_layer->setMode(nvinfer1::SliceMode::kFILL);
+       slice_layer->setInput(4, *value_itensor);
+
+       if (ctx->input_is_dynamic) {
+         // build the size using inetwork layers
+         auto shape_layer = ctx->net->addShape(*in);
+         TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+         shape_layer->setName((util::node_info(n) + "_shape").c_str());
+         auto total_padding_itensor = tensor_to_const(ctx, torch::tensor(total_padding, torch::kInt32));
+
+         auto add_layer = ctx->net->addElementWise(
+             *shape_layer->getOutput(0), *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM);
+         TORCHTRT_CHECK(add_layer, "Unable to create add layer from node: " << *n);
+         add_layer->setName((util::node_info(n) + "_add").c_str());
+         slice_layer->setInput(2, *add_layer->getOutput(0));
       }
-       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));

        LOG_DEBUG("Output tensor shape: " << out->getDimensions());
        return true;
      }});
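Reviewer note: below is a minimal standalone sketch of the start/size bookkeeping the new converter performs. The shape and padding values are hypothetical, purely for illustration, and not part of the patch. The idea is that with `ISliceLayer` in `kFILL` mode, a negative `start` plus an enlarged `size` reads out of bounds on both sides of each padded dim, and TensorRT fills those positions with the pad value supplied on input 4.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical rank-4 input (N, C, H, W) and PyTorch-style pad
  // {w_left, w_right, h_top, h_bottom}; none of these values come from the diff.
  const int64_t in_rank = 4;
  const std::vector<int64_t> in_dims = {1, 3, 8, 8};
  const std::vector<int64_t> padding = {1, 2, 3, 4};

  std::vector<int64_t> start(in_rank, 0);
  std::vector<int64_t> total_padding(in_rank, 0);

  // Pads are stored (left, right) from the last dim backwards, as in the diff.
  for (size_t i = 0; i < padding.size(); i += 2) {
    const int64_t idx = in_rank - ((i / 2) + 1);
    start[idx] = -padding[i]; // negative start begins the slice in the fill region
    total_padding[idx] = padding[i] + padding[i + 1];
  }

  // In the static-shape case the slice size is simply in_dims + total_padding.
  for (int64_t d = 0; d < in_rank; ++d) {
    std::cout << "dim " << d << ": start=" << start[d]
              << " size=" << (in_dims[d] + total_padding[d]) << "\n";
  }
  // dim 2: start=-3 size=15, dim 3: start=-1 size=11; dims 0-1 are untouched.
  return 0;
}
```

In the dynamic branch the converter builds the same size in-network instead: an IShapeLayer reads the runtime shape, an elementwise kSUM adds `total_padding`, and the result overrides the slice size via `setInput(2, ...)`.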