Skip to content

Commit

Permalink
add matmul shape constrain (PaddlePaddle#62567)
Browse files Browse the repository at this point in the history
  • Loading branch information
phlrain authored and hitywt committed Mar 11, 2024
1 parent c928bc7 commit 5089edc
Showing 1 changed file with 19 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,25 @@ bool MatmulOpInferSymbolicShape(
shape_analysis->SetShapeOrDataForValue(op->result(0),
ShapeOrData{TensorExprs(out_dims)});

if ((ndims_x == ndims_y) && ndims_x >= 2) {
if (transpose_x_attr == false && transpose_y_attr == false) {
shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 1],
y_dims[ndims_x - 2]);
} else if (transpose_x_attr == false && transpose_y_attr == true) {
shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 1],
y_dims[ndims_x - 1]);
} else if (transpose_x_attr == true && transpose_y_attr == false) {
shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 2],
y_dims[ndims_x - 2]);
} else {
shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[ndims_x - 2],
y_dims[ndims_x - 1]);
}

for (size_t i = 0; i < ndims_x - 2; ++i) {
shape_analysis->CreateDimExprBuilder().CstrEq(x_dims[i], y_dims[i]);
}
}
return true;
}

Expand Down

0 comments on commit 5089edc

Please sign in to comment.