-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Infer Symbolic Shape No.166][BUAA] rms_norm
#67294
[Infer Symbolic Shape No.166][BUAA] rms_norm
#67294
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
这里 coverage 没过的测试好像都不是我改的内容,麻烦帮忙看一下。 |
timeout了,帮你rerun了 |
// } | ||
bool RmsNormOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const symbol::ShapeOrDataDimExprs &x_shape = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范,x_shape_or_data
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改。
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const symbol::ShapeOrDataDimExprs &x_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
std::vector<symbol::DimExpr> x_dims = x_shape.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改。
const symbol::ShapeOrDataDimExprs &x_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
std::vector<symbol::DimExpr> x_dims = x_shape.shape(); | ||
size_t x_dims_size = x_dims.size(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x_shape_size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改。
|
||
const auto &norm_weight_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(3)); | ||
const std::vector<symbol::DimExpr> norm_weight_dims = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const std::vectorsymbol::DimExpr &,避免拷贝
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
请修改上述问题
收到 |
LGTM |
* Finished rms_norm op * Resolved suggested changes
PR Category
CINN
PR Types
Others
Description
添加
rms_norm
中等难度算子的符号推导接口实现,由于中等难度的算子对我而言难度比较大,所以这个算子我先单独提交的,已通过本地测试。