Skip to content

Commit

Permalink
change api to support trt8 in pool3d_op_convert (#36783)
Browse files Browse the repository at this point in the history
* change api for support trt8

* fix:change api
  • Loading branch information
feng_shuai committed Oct 28, 2021
1 parent 7cb7535 commit e2d09c2
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions paddle/fluid/inference/tensorrt/convert/pool3d_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ namespace tensorrt {

inline void DealCeilMode(const nvinfer1::Dims &input_shape,
std::vector<int> ksize, std::vector<int> strides,
std::vector<int> paddings, nvinfer1::DimsCHW *pre_pad,
nvinfer1::DimsCHW *post_pad, int input_dims) {
std::vector<int> paddings, nvinfer1::Dims3 *pre_pad,
nvinfer1::Dims3 *post_pad, int input_dims) {
int input_depth = input_shape.d[input_dims - 3];
int input_height = input_shape.d[input_dims - 2];
int input_width = input_shape.d[input_dims - 1];
Expand All @@ -56,15 +56,15 @@ inline void DealCeilMode(const nvinfer1::Dims &input_shape,
1;

if (floor_d_output_size != ceil_d_output_size) {
post_pad->c() = strides[0] - 1;
post_pad->d[0] = strides[0] - 1;
}

if (floor_h_output_size != ceil_h_output_size) {
post_pad->h() = strides[1] - 1;
post_pad->d[1] = strides[1] - 1;
}

if (floor_w_output_size != ceil_w_output_size) {
post_pad->w() = strides[2] - 1;
post_pad->d[2] = strides[2] - 1;
}
}

Expand Down Expand Up @@ -118,9 +118,9 @@ class Pool3dOpConverter : public OpConverter {
reduce_operation = nvinfer1::ReduceOperation::kAVG;
plugin_pool_type = plugin::Pool3DPlugin::Pool3DType::avg;
}
nvinfer1::DimsCHW nv_ksize(ksize[0], ksize[1], ksize[2]);
nvinfer1::DimsCHW nv_strides(strides[0], strides[1], strides[2]);
nvinfer1::DimsCHW nv_paddings(paddings[0], paddings[1], paddings[2]);
nvinfer1::Dims3 nv_ksize(ksize[0], ksize[1], ksize[2]);
nvinfer1::Dims3 nv_strides(strides[0], strides[1], strides[2]);
nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]);
nvinfer1::ILayer *layer = nullptr;
if (op_desc.HasAttr("enable_int8")) {
CHECK(op_desc.HasAttr("X_scale"));
Expand Down

1 comment on commit e2d09c2

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.