fix: Rewrite constant_pad_nd to use a single slice layer for performance #1970

Merged
merged 1 commit into from Jun 2, 2023
162 changes: 49 additions & 113 deletions core/conversion/converters/impl/constant_pad.cpp
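
The rewrite expresses constant padding as a single TensorRT slice in fill mode: a negative start exposes the left/top/front pad, an enlarged size exposes the right/bottom/back pad, and out-of-bounds reads are filled with the pad value. As an editor's sketch (not part of the patch; `compute_slice_params` is a hypothetical standalone helper), the index math in the new converter boils down to:

```cpp
// Editor's sketch: the start/size arithmetic behind the single-slice rewrite.
// Mirrors the converter's loop over `padding`; names are hypothetical.
#include <cstdint>
#include <iostream>
#include <vector>

struct SliceParams {
  std::vector<int64_t> start;   // negative entries expose left/top/front padding
  std::vector<int64_t> size;    // padded extent per dimension (static-shape case)
  std::vector<int64_t> stride;  // always 1 for padding
};

SliceParams compute_slice_params(const std::vector<int64_t>& in_dims, const std::vector<int64_t>& pad) {
  const auto rank = static_cast<int64_t>(in_dims.size());
  SliceParams p{std::vector<int64_t>(rank, 0), in_dims, std::vector<int64_t>(rank, 1)};
  // pad holds (left, right) pairs, last dimension first.
  for (size_t i = 0; i < pad.size(); i += 2) {
    const int64_t idx = rank - static_cast<int64_t>(i / 2) - 1;
    p.start[idx] = -pad[i];              // shift the window left by the leading pad
    p.size[idx] += pad[i] + pad[i + 1];  // grow the window by both pads
  }
  return p;
}

int main() {
  // A (1, 3, 4, 5) input with pad = {1, 2, 0, 3} gives
  // start = {0, 0, 0, -1}, size = {1, 3, 7, 8}, stride = {1, 1, 1, 1}.
  const auto p = compute_slice_params({1, 3, 4, 5}, {1, 2, 0, 3});
  for (auto s : p.start) std::cout << s << ' ';
  std::cout << "| ";
  for (auto s : p.size) std::cout << s << ' ';
  std::cout << '\n';
  // In the converter these become the ISliceLayer arguments, with
  // SliceMode::kFILL filling out-of-bounds reads with the constant value.
  return 0;
}
```
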
@@ -16,127 +16,63 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
{"aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto inDims = in->getDimensions();
int64_t inRank = inDims.nbDims;
auto in_dims = in->getDimensions();
int64_t in_rank = in_dims.nbDims;
auto padding = args[1].unwrapToIntList().vec();
int64_t padSize = padding.size();
int64_t pad_size = padding.size();
auto value = args[2].unwrapToScalar().to<float>();
at::Tensor value_tensor = torch::tensor(value, util::TRTDataTypeToScalarType(in->getType()));
auto valueTensor = tensor_to_const(ctx, value_tensor);
TORCHTRT_CHECK(padSize % 2 == 0, "Length of pad must be even but instead it equals " << padSize);

int64_t l_pad = padSize / 2;
TORCHTRT_CHECK(
inRank >= (int64_t)l_pad,
"Length of pad should be no more than twice the number of "
"dimensions of the input. Pad length is "
<< padSize << "while the input has " << inRank << "dimensions.");

// TODO negative padding. When the pad is negative, we need to crop the image.

std::vector<nvinfer1::ITensor*> tensors_vec;
// input: (N, C, D_in, H_in, W_in).
// padding: (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
// When axis is inRank - 1, making W_out = W_in + padding_left + padding_right.
// When axis is inRank - 2, making H_out = H_in + padding_top + padding_bottom.
// When axis is inRank - 3, making D_out = D_in + padding_front + padding_back.
for (int64_t i = 0; i < l_pad; i++) {
int64_t axis = inRank - (i + 1); // axis = {inRank - 1, inRank - 2, inRank - 3}
int64_t padding_index = i * 2;

if (padding[padding_index] > 0) { // left/top/front padding value
tensors_vec.clear();
if (ctx->input_is_dynamic) {
at::Tensor left_indices = torch::tensor({0}, torch::kInt32);
auto indicesTensor = tensor_to_const(ctx, left_indices);
auto left_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
auto left_gather_out = left_gather_layer->getOutput(0);

// fill the left_gather_out with value
auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
auto shape_gather_out = ctx->net->addShape(*left_gather_out)->getOutput(0);
fill_layer->setInput(0, *shape_gather_out);
fill_layer->setInput(1, *valueTensor);
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
fill_layer->setInput(2, *deltaTensor);
auto padTensor = fill_layer->getOutput(0);

for (int i = 0; i < padding[padding_index]; i++) {
tensors_vec.push_back(padTensor);
}
} else {
inDims.d[axis] = padding[padding_index];
auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
fill_layer->setInput(1, *valueTensor);
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
fill_layer->setInput(2, *deltaTensor);
auto padTensor = fill_layer->getOutput(0);

tensors_vec.push_back(padTensor);
}
auto value_itensor = tensor_to_const(ctx, value_tensor);
TORCHTRT_CHECK(pad_size % 2 == 0, "Length of pad must be even but instead it equals " << pad_size);

std::vector<int64_t> start(in_rank, 0);
std::vector<int64_t> total_padding(in_rank, 0);
std::vector<int64_t> stride(in_rank, 1);

// Padding is stored (left, right) starting from the last dim and working backwards
for (size_t i = 0UL; i < padding.size(); i += 2) {
auto left = padding[i];
TORCHTRT_CHECK(left >= 0, "Unsupported negative pad at index " << i);
auto right = padding[i + 1];
TORCHTRT_CHECK(right >= 0, "Unsupported negative pad at index " << i + 1);
auto idx = in_rank - ((i / 2) + 1);
start[idx] = -left;
total_padding[idx] = left + right;
}

tensors_vec.push_back(in);
auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
concat_layer->setAxis(axis);
in = concat_layer->getOutput(0);
inDims = in->getDimensions();
auto size = stride; // placeholder for the dynamic case
if (!ctx->input_is_dynamic) {
size = total_padding;
for (size_t i = 0UL; i < total_padding.size(); ++i) {
size[i] += in_dims.d[i];
}
}

if (padding[padding_index + 1] > 0) { // right/bottom/back padding value
tensors_vec.clear();
tensors_vec.push_back(in);

nvinfer1::ITensor* indicesTensor = NULL;
if (inDims.d[axis] == -1) {
auto shapeTensor = ctx->net->addShape(*in)->getOutput(0);
at::Tensor dimValue = torch::tensor({axis}, torch::kInt32);
auto dimTensor = tensor_to_const(ctx, dimValue);
indicesTensor = ctx->net->addGather(*shapeTensor, *dimTensor, 0)->getOutput(0);
auto oneTensor = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
indicesTensor = ctx->net->addElementWise(*indicesTensor, *oneTensor, nvinfer1::ElementWiseOperation::kSUB)
->getOutput(0);
} else {
auto indices = torch::tensor({inDims.d[axis] - 1}, torch::kInt32);
indicesTensor = tensor_to_const(ctx, indices);
}
auto right_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
auto right_gather_out = right_gather_layer->getOutput(0);

if (ctx->input_is_dynamic) {
// fill the right_gather_out with value
auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
auto shape_gather_out = ctx->net->addShape(*right_gather_out)->getOutput(0);
fill_layer->setInput(0, *shape_gather_out);
fill_layer->setInput(1, *valueTensor);
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
fill_layer->setInput(2, *deltaTensor);
auto padTensor = fill_layer->getOutput(0);

for (int i = 0; i < padding[padding_index + 1]; i++) {
tensors_vec.push_back(padTensor);
}
} else {
inDims.d[axis] = padding[padding_index + 1];
auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
fill_layer->setInput(1, *valueTensor);
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
fill_layer->setInput(2, *deltaTensor);
auto padTensor = fill_layer->getOutput(0);

tensors_vec.push_back(padTensor);
}
auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
concat_layer->setAxis(axis);
in = concat_layer->getOutput(0);
inDims = in->getDimensions();
}
auto slice_layer = ctx->net->addSlice(
*in,
util::toDims(c10::IntArrayRef(start)),
util::toDims(c10::IntArrayRef(size)),
util::toDims(c10::IntArrayRef(stride)));
TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
slice_layer->setName((util::node_info(n) + "_slice").c_str());
slice_layer->setMode(nvinfer1::SliceMode::kFILL);
slice_layer->setInput(4, *value_itensor);

if (ctx->input_is_dynamic) {
// build the size using inetwork layers
auto shape_layer = ctx->net->addShape(*in);
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
shape_layer->setName((util::node_info(n) + "_shape").c_str());
auto total_padding_itensor = tensor_to_const(ctx, torch::tensor(total_padding, torch::kInt32));

auto add_layer = ctx->net->addElementWise(
*shape_layer->getOutput(0), *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM);
TORCHTRT_CHECK(add_layer, "Unable to create add layer from node: " << *n);
add_layer->setName((util::node_info(n) + "_add").c_str());
slice_layer->setInput(2, *add_layer->getOutput(0));
}

auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}});
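
For reference, the op this converter maps consumes the pad list in (left, right) pairs starting from the last dimension, which is what the loop over `padding` in the new code mirrors. A minimal LibTorch check of that layout (an editor's illustration, assuming a standard LibTorch build; shapes and the fill value are arbitrary):

```cpp
// Editor's illustration of the aten::constant_pad_nd semantics the converter maps.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::zeros({1, 3, 4, 5});
  // pad = {left_W, right_W, top_H, bottom_H}: W grows by 1 + 2, H by 0 + 3.
  auto y = torch::constant_pad_nd(x, {1, 2, 0, 3}, /*value=*/0.5);
  std::cout << y.sizes() << '\n';  // expect [1, 3, 7, 8]
  return 0;
}
```

With the single slice layer, the converter produces the same padded shape while emitting one TRT layer in the static-shape case (plus a shape/add pair in the dynamic case) instead of a fill/gather/concat chain per padded dimension.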