-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Infer Symbolic Shape No.124】【BUAA】Add bmm, changed 3 files #67431
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
auto cal_shape_fn = [](const symbol::DimExpr &x, | ||
const symbol::DimExpr &y, | ||
const std::string &error_str) -> symbol::DimExpr { | ||
if (x == -1) { | ||
return y; | ||
} else if (y == -1) { | ||
return x; | ||
} | ||
PADDLE_ENFORCE_EQ(x, y, common::errors::InvalidArgument(error_str, x, y)); | ||
return x; | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DimExpr不可能为-1,这段逻辑是想给两个DimExpr加约束
"in BmmOp, but received Y's shape: [%d].", | ||
y_ndims)); | ||
|
||
auto cal_shape_fn = [](const symbol::DimExpr &x, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删掉这个lambda表达式,调用cal_shape_fn的地方都使用 addequalcstr
return x; | ||
}; | ||
|
||
cal_shape_fn(x_dims[2], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修改为 equal constrain
y_dims[1], | ||
"Input(X)'s width must be equal with Input(Y)'s height in " | ||
"BmmOp, but receive X's width: [%d], Y's height: [%d]."); | ||
symbol::DimExpr batch_size = cal_shape_fn( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,batch_size 直接取其中一个Dim就行
const auto &y_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
|
||
const std::vector<symbol::DimExpr> &x_dims = x_shape_or_data.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范,x_shape
infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
|
||
const std::vector<symbol::DimExpr> &x_dims = x_shape_or_data.shape(); | ||
const std::vector<symbol::DimExpr> &y_dims = y_shape_or_data.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
冲突了 |
PR Category
CINN
PR Types
Others
Description
加入bmm, broadcast_tensors:
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h
paddle/phi/ops/yaml/ops.yaml
本地测试pir=True且通过