diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 8569dd6347852..7c5af43816c44 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -228,7 +228,7 @@ class ElementwiseTensorOpConverter : public OpConverter { } }; - if (CheckDims(dims_x, dims_y)) { + if (dims_x.nbDims == dims_y.nbDims) { // The two input tensor should have the same dims VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer"; nvinfer1::IElementWiseLayer* elet_layer = diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py index 992e0353837bc..b54b923d3b086 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py @@ -317,21 +317,28 @@ def generate_input(shape): input1_shape_list = [[4, 32], [2, 4, 32], [4, 2, 4, 32]] input2_shape1_list = [[32], [4, 32], [2, 4, 32]] input2_shape2_list = [[4, 1], [2, 4, 1], [4, 2, 4, 1]] - input2_shape3_list = [[32], [2, 1, 1], [4, 2, 1, 1]] - input2_shape4_list = [[32], [4, 32], [4, 1, 1, 1]] + input2_shape3_list = [[32], [2, 1, 1], [4, 2, 1, 32]] + input2_shape4_list = [[32], [4, 32], [4, 1, 4, 32]] + input2_shape5_list = [[32], [2, 1, 32], [4, 1, 1, 32]] + input2_shape6_list = [[1, 32], [1, 32], [1, 1, 1, 32]] input2_shape_list = [ input2_shape1_list, input2_shape2_list, input2_shape3_list, - input2_shape4_list + input2_shape4_list, input2_shape5_list, input2_shape6_list ] axis1_list = [[-1], [1, -1], [1, -1]] axis2_list = [[-1], [0], [0]] axis3_list = [[-1], [0], [0]] axis4_list = [[-1], [-1], [0]] - axis_list = [axis1_list, axis2_list, axis3_list, axis4_list] + axis5_list = [[-1, 1], [-1, 0], [-1, 0]] + axis6_list = [[-1, 0], [-1, 1], [-1, 0]] + axis_list = [ + axis1_list, axis2_list, axis3_list, axis4_list, axis5_list, + axis6_list + ] for i in range(3): input1_shape = input1_shape_list[i] - for j in range(4): + for j in range(6): input2_shape = input2_shape_list[j][i] for op_type in ["elementwise_add", "elementwise_mul"]: for axis in axis_list[j][i]: