-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【BUAA】【Infer Symbolic Shape】add mv, shadow_feed, share_data_ #66956
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -720,6 +720,39 @@ bool MatmulOpInferSymbolicShape(pir::Operation *op, | |
return true; | ||
} | ||
|
||
bool MvOpInferSymbolicShape(pir::Operation *op, | ||
pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &input_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const auto &vec_shape = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上,取出shape域的值后再用shape后缀 |
||
infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
PADDLE_ENFORCE_EQ( | ||
input_shape.shape().size(), | ||
2, | ||
phi::errors::InvalidArgument("The rank of input X should be 2, but is %d", | ||
input_shape.shape().size())); | ||
PADDLE_ENFORCE_EQ(vec_shape.shape().size(), | ||
1, | ||
phi::errors::InvalidArgument( | ||
"The rank of input Vec should be 1, but is %d", | ||
vec_shape.shape().size())); | ||
PADDLE_ENFORCE_EQ(input_shape.shape()[1], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里添加的是 DimExpr 之间的约束,应该使用:infer_context->Addequalcst() |
||
vec_shape.shape()[0], | ||
phi::errors::InvalidArgument( | ||
"X's second dimension is expected to be equal to " | ||
"Vec's first dimension" | ||
"but received X'shape = [%d], Vec's shape = [%d]", | ||
input_shape.shape()[1], | ||
vec_shape.shape()[0])); | ||
|
||
std::vector<symbol::DimExpr> out_shape = {input_shape.shape()[0]}; | ||
infer_context->SetShapeOrDataForValue( | ||
op->result(0), | ||
symbol::ShapeOrDataDimExprs{ | ||
symbol::TensorShapeOrDataDimExprs(out_shape)}); | ||
return true; | ||
} | ||
|
||
// bool PullBoxSparseOpInferSymbolicShape(pir::Operation *op, | ||
// pir::InferSymbolicShapeContext | ||
// *infer_context) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -133,6 +133,9 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(ScatterNdAdd) | |
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter_) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Select) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShadowFeed) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShareData_) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShareData__) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 带下划线的算子inplace名和原算子一致,删掉 ShareData__ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎉恭喜触发了隐藏bug,麻烦在周会时跟大家分享 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修复上述问题:7cfe49e |
||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sign) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin_) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范,这里是shape_or_data 类型,规范命名:x_shape_or_data