From 5089edcb9ed1783f005c7a9f7ca6b1e12d8cce38 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Mon, 11 Mar 2024 10:26:43 +0800 Subject: [PATCH] add matmul shape constrain (#62567) --- .../paddle_op_infer_sym.cc | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc index 4d3f0222de40c9..ee4f2d406b3a29 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc @@ -844,6 +844,25 @@ bool MatmulOpInferSymbolicShape( shape_analysis->SetShapeOrDataForValue(op->result(0), ShapeOrData{TensorExprs(out_dims)}); + if ((ndims_x == ndims_y) && ndims_x >= 2) { + if (transpose_x_attr == false && transpose_y_attr == false) { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 1], + y_dims[ndims_x - 2]); + } else if (transpose_x_attr == false && transpose_y_attr == true) { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 1], + y_dims[ndims_x - 1]); + } else if (transpose_x_attr == true && transpose_y_attr == false) { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 2], + y_dims[ndims_x - 2]); + } else { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 2], + y_dims[ndims_x - 1]); + } + + for (size_t i = 0; i < ndims_x - 2; ++i) { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[i], y_dims[i]); + } + } return true; }