Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【BUAA】【Infer Symbolic Shape】add mv, shadow_feed, share_data_ #66956

Merged
merged 5 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,39 @@ bool MatmulOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool MvOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &input_shape =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

命名不规范,这里是shape_or_data 类型,规范命名:x_shape_or_data

infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &vec_shape =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,取出shape域的值后再用shape后缀

infer_context->GetShapeOrDataForValue(op->operand_source(1));
PADDLE_ENFORCE_EQ(
input_shape.shape().size(),
2,
phi::errors::InvalidArgument("The rank of input X should be 2, but is %d",
input_shape.shape().size()));
PADDLE_ENFORCE_EQ(vec_shape.shape().size(),
1,
phi::errors::InvalidArgument(
"The rank of input Vec should be 1, but is %d",
vec_shape.shape().size()));
PADDLE_ENFORCE_EQ(input_shape.shape()[1],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里添加的是 DimExpr 之间的约束,应该使用:infer_context->Addequalcst()

vec_shape.shape()[0],
phi::errors::InvalidArgument(
"X's second dimension is expected to be equal to "
"Vec's first dimension"
"but received X'shape = [%d], Vec's shape = [%d]",
input_shape.shape()[1],
vec_shape.shape()[0]));

std::vector<symbol::DimExpr> out_shape = {input_shape.shape()[0]};
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});
return true;
}

// bool PullBoxSparseOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRankTol)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedSelect)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Matmul)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mv)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nextafter)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullBoxSparse)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullGpuPsSparse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ OP_SAME_OPERANDS_AND_RESULT(ScatterNdAdd)
OP_SAME_OPERANDS_AND_RESULT(Scatter)
OP_SAME_OPERANDS_AND_RESULT(Scatter_)
OP_SAME_OPERANDS_AND_RESULT(Select)
OP_SAME_OPERANDS_AND_RESULT(ShadowFeed)
OP_SAME_OPERANDS_AND_RESULT(ShareData_)
OP_SAME_OPERANDS_AND_RESULT(ShareData__)
OP_SAME_OPERANDS_AND_RESULT(Sign)
OP_SAME_OPERANDS_AND_RESULT(Sin)
OP_SAME_OPERANDS_AND_RESULT(Sin_)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(ScatterNdAdd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Select)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShadowFeed)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShareData_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShareData__)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

带下划线的算子inplace名和原算子一致,删掉 ShareData__

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉恭喜触发了隐藏bug,麻烦在周会时跟大家分享

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修复上述问题:7cfe49e

OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sign)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin_)
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/ops/yaml/inconsistent/static_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@
kernel:
func: shadow_feed
param: [x]
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : shadow_feed_tensors
args : (Tensor[] x)
Expand All @@ -832,6 +833,7 @@
param: [x]
inplace : (x -> out)
traits : paddle::dialect::ForwardOnlyTrait
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : soft_relu
args : (Tensor x, float threshold = 20.0f)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3286,6 +3286,7 @@
kernel :
func : mv
backward : mv_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : nadam_
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor momentum_decay_pow, Tensor beta2_pow, Tensor mu_product, Tensor moment1, Tensor moment2, Tensor master_param, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1.0e-8f, float momentum_decay = 0.004f, bool multi_precision = false)
Expand Down