Skip to content

Commit

Permalink
[inference Zero-Dim]prelu trt converter support zero dim tensor (Padd…
Browse files Browse the repository at this point in the history
…lePaddle#53634)

* prelu op trt converter support zero dim
  • Loading branch information
yuanlehome authored and zhangjun committed May 12, 2023
1 parent 7b811e3 commit 9f8b751
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 173 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/trt_support_nhwc_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,8 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const {
}
};
InsertTransposeOp();

AddStatis(transposed_ops.size());
}

} // namespace ir
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/inference/tensorrt/convert/prelu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ class PReluOpConverter : public OpConverter {
if (hw_tensor != nullptr) {
shape_tensor = Concat(
std::vector<nvinfer1::ITensor*>{n_tensor, c_tensor, hw_tensor});

} else {
shape_tensor =
Concat(std::vector<nvinfer1::ITensor*>{n_tensor, c_tensor});
Expand Down
28 changes: 14 additions & 14 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1845,28 +1845,28 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass.";
return false;
}
auto* var_desc = block->FindVar(desc.Input("Alpha")[0]);
if (!var_desc) {
auto* alpha_var = block->FindVar(desc.Input("Alpha")[0]);
if (!alpha_var) {
VLOG(3) << "Variable Alpha of prelu TRT converter not found.";
return false;
}

auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (!with_dynamic_shape && x_shape.size() == 1) {
VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt "
"with static shape.";
auto alpha_shape = alpha_var->GetShape();
if (!with_dynamic_shape && alpha_shape.size() == 0) {
VLOG(3) << op_type
<< " op does not support alpha's dim is 0 in tensorrt "
"static shape mode.";
return false;
}

#if IS_TRT_VERSION_LT(7000)
if (!with_dynamic_shape) {
// TODO(inference): fix trt6 static plugin error.
VLOG(3) << "prelu static plugin in trt6 has bug.";
auto x_var_name = desc.Input("X")[0];
auto* x_var = block->FindVar(x_var_name);
const auto x_shape = x_var->GetShape();
if (!with_dynamic_shape && (x_shape.size() == 1 || x_shape.size() == 0)) {
VLOG(3) << op_type
<< " op does not support input's dim is 1 or 0 in tensorrt "
"with static shape.";
return false;
}
#endif
}

if (op_type == "mish") {
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2328,7 +2328,7 @@ void PReluInferMeta(const MetaTensor& x,
1,
phi::errors::InvalidArgument(
"For mode 'element', rank of input X must be "
"equal or larger than 2. But recevied X's "
"equal or larger than 1. But recevied X's "
"rank: %d",
x_rank));
PADDLE_ENFORCE_EQ(
Expand Down
Loading

0 comments on commit 9f8b751

Please sign in to comment.