Skip to content

Commit

Permalink
[Paddle-Inference] fix_ele_convert: IElementWiseLayer can broadcast (#…
Browse files Browse the repository at this point in the history
…37908)

* fix_ele_convert: IElementWiseLayer can broadcast

* fix_ele_convert
  • Loading branch information
Wangzheee authored Dec 9, 2021
1 parent 1911b6f commit f695dc9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class ElementwiseTensorOpConverter : public OpConverter {
}
};

if (CheckDims(dims_x, dims_y)) {
if (dims_x.nbDims == dims_y.nbDims) {
// The two input tensor should have the same dims
VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer";
nvinfer1::IElementWiseLayer* elet_layer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,21 +317,28 @@ def generate_input(shape):
input1_shape_list = [[4, 32], [2, 4, 32], [4, 2, 4, 32]]
input2_shape1_list = [[32], [4, 32], [2, 4, 32]]
input2_shape2_list = [[4, 1], [2, 4, 1], [4, 2, 4, 1]]
input2_shape3_list = [[32], [2, 1, 1], [4, 2, 1, 1]]
input2_shape4_list = [[32], [4, 32], [4, 1, 1, 1]]
input2_shape3_list = [[32], [2, 1, 1], [4, 2, 1, 32]]
input2_shape4_list = [[32], [4, 32], [4, 1, 4, 32]]
input2_shape5_list = [[32], [2, 1, 32], [4, 1, 1, 32]]
input2_shape6_list = [[1, 32], [1, 32], [1, 1, 1, 32]]
input2_shape_list = [
input2_shape1_list, input2_shape2_list, input2_shape3_list,
input2_shape4_list
input2_shape4_list, input2_shape5_list, input2_shape6_list
]
axis1_list = [[-1], [1, -1], [1, -1]]
axis2_list = [[-1], [0], [0]]
axis3_list = [[-1], [0], [0]]
axis4_list = [[-1], [-1], [0]]
axis_list = [axis1_list, axis2_list, axis3_list, axis4_list]
axis5_list = [[-1, 1], [-1, 0], [-1, 0]]
axis6_list = [[-1, 0], [-1, 1], [-1, 0]]
axis_list = [
axis1_list, axis2_list, axis3_list, axis4_list, axis5_list,
axis6_list
]

for i in range(3):
input1_shape = input1_shape_list[i]
for j in range(4):
for j in range(6):
input2_shape = input2_shape_list[j][i]
for op_type in ["elementwise_add", "elementwise_mul"]:
for axis in axis_list[j][i]:
Expand Down

0 comments on commit f695dc9

Please sign in to comment.