diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc index 4d3f0222de40c9..ee4f2d406b3a29 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc @@ -844,6 +844,25 @@ bool MatmulOpInferSymbolicShape( shape_analysis->SetShapeOrDataForValue(op->result(0), ShapeOrData{TensorExprs(out_dims)}); + if ((ndims_x == ndims_y) && ndims_x >= 2) { + if (transpose_x_attr == false && transpose_y_attr == false) { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 1], + y_dims[ndims_x - 2]); + } else if (transpose_x_attr == false && transpose_y_attr == true) { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 1], + y_dims[ndims_x - 1]); + } else if (transpose_x_attr == true && transpose_y_attr == false) { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 2], + y_dims[ndims_x - 2]); + } else { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 2], + y_dims[ndims_x - 1]); + } + + for (size_t i = 0; i < ndims_x - 2; ++i) { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[i], y_dims[i]); + } + } return true; }